diff --git a/.github/workflows/desktop-release.yml b/.github/workflows/desktop-release.yml
index e356bd3e5..ad1c128bc 100644
--- a/.github/workflows/desktop-release.yml
+++ b/.github/workflows/desktop-release.yml
@@ -144,6 +144,11 @@ jobs:
APPLE_ID: ${{ secrets.APPLE_ID }}
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
+ # TEMP DEBUG — remove once the codesign hang on macos-latest is diagnosed.
+ # Surfaces the exact codesign / notarize commands electron-builder spawns,
+ # so we can see which subprocess hangs.
+ DEBUG: electron-builder,electron-osx-sign*,@electron/notarize*
+ ELECTRON_BUILDER_ALLOW_UNRESOLVED_DEPENDENCIES: "true"
# Service principal credentials for Azure.Identity EnvironmentCredential used by the
# TrustedSigning PowerShell module. Only populated when signing is enabled.
# electron-builder 26 does not yet support OIDC federated tokens for Azure signing,
diff --git a/.github/workflows/notary-status.yml b/.github/workflows/notary-status.yml
new file mode 100644
index 000000000..5c7c42038
--- /dev/null
+++ b/.github/workflows/notary-status.yml
@@ -0,0 +1,60 @@
+name: Notary status check
+
+# One-off diagnostic workflow. Queries Apple's notary service to see if your
+# submissions are queued, in progress, accepted, or rejected. Useful when a
+# notarization seems "hung" — most often the cause is Apple's queue itself,
+# especially on a brand-new Apple Developer account.
+#
+# Run via: Actions tab -> "Notary status check" -> Run workflow.
+# Inputs are optional; if you provide a submission ID, it also fetches that
+# submission's full Apple log.
+#
+# Safe to delete after diagnosis.
+
+on:
+ workflow_dispatch:
+ inputs:
+ submission_id:
+ description: 'Optional: submission UUID to fetch full Apple log for'
+ required: false
+ default: ''
+
+jobs:
+ status:
+ runs-on: macos-latest
+ steps:
+ - name: List recent notarization submissions
+ env:
+ APPLE_ID: ${{ secrets.APPLE_ID }}
+ APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
+ APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
+ run: |
+ set -euo pipefail
+ echo "::group::Submission history (most recent first)"
+ xcrun notarytool history \
+ --apple-id "$APPLE_ID" \
+ --password "$APPLE_APP_SPECIFIC_PASSWORD" \
+ --team-id "$APPLE_TEAM_ID"
+ echo "::endgroup::"
+
+ - name: Inspect specific submission (if id provided)
+ if: ${{ inputs.submission_id != '' }}
+ env:
+ APPLE_ID: ${{ secrets.APPLE_ID }}
+ APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
+ APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
+ SUBMISSION_ID: ${{ inputs.submission_id }}
+ run: |
+ set -euo pipefail
+ echo "::group::Submission info"
+ xcrun notarytool info "$SUBMISSION_ID" \
+ --apple-id "$APPLE_ID" \
+ --password "$APPLE_APP_SPECIFIC_PASSWORD" \
+ --team-id "$APPLE_TEAM_ID"
+ echo "::endgroup::"
+ echo "::group::Apple's processing log for this submission"
+ xcrun notarytool log "$SUBMISSION_ID" \
+ --apple-id "$APPLE_ID" \
+ --password "$APPLE_APP_SPECIFIC_PASSWORD" \
+ --team-id "$APPLE_TEAM_ID" || true
+ echo "::endgroup::"
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 029e7c647..ad8f8f2a7 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -26,7 +26,16 @@
"pythonArgs": [
"run",
"python"
- ]
+ ],
+ // Mute LangGraph/Pydantic checkpoint serializer warnings
+ // (UserWarnings emitted from pydantic/main.py when the
+ // runtime snapshots a SurfSenseContextSchema into a field
+ // typed `None`) so the debugger's "Raised Exceptions"
+ // breakpoint doesn't pause on a known-harmless event.
+ // Production logs are unaffected.
+ "env": {
+ "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
+ }
},
{
"name": "Backend: FastAPI (No Reload)",
@@ -40,7 +49,10 @@
"pythonArgs": [
"run",
"python"
- ]
+ ],
+ "env": {
+ "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
+ }
},
{
"name": "Backend: FastAPI (main.py)",
@@ -54,7 +66,10 @@
"pythonArgs": [
"run",
"python"
- ]
+ ],
+ "env": {
+ "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
+ }
},
{
"name": "Frontend: Next.js",
@@ -104,7 +119,10 @@
"pythonArgs": [
"run",
"python"
- ]
+ ],
+ "env": {
+ "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
+ }
},
{
"name": "Celery: Beat Scheduler",
@@ -124,7 +142,10 @@
"pythonArgs": [
"run",
"python"
- ]
+ ],
+ "env": {
+ "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
+ }
}
],
"compounds": [
diff --git a/VERSION b/VERSION
index 44517d518..236c7ad08 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.0.19
+0.0.21
diff --git a/docker/.env.example b/docker/.env.example
index 95de0cf85..fd56bdccc 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
# STRIPE_RECONCILIATION_BATCH_SIZE=100
-# Premium token purchases ($1 per 1M tokens for premium-tier models)
+# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
+# credit; premium turns debit the actual per-call provider cost
+# reported by LiteLLM, so cheap and expensive models bill proportionally)
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
-# STRIPE_TOKENS_PER_UNIT=1000000
+# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
+# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000
# ------------------------------------------------------------------------------
# TTS & STT (Text-to-Speech / Speech-to-Text)
@@ -305,6 +308,24 @@ STT_SERVICE=local/base
# Advanced (optional)
# ------------------------------------------------------------------------------
+# New-chat agent feature flags
+SURFSENSE_ENABLE_CONTEXT_EDITING=true
+SURFSENSE_ENABLE_COMPACTION_V2=true
+SURFSENSE_ENABLE_RETRY_AFTER=true
+SURFSENSE_ENABLE_MODEL_FALLBACK=false
+SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
+SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
+SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
+SURFSENSE_ENABLE_BUSY_MUTEX=true
+SURFSENSE_ENABLE_SKILLS=true
+SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true
+SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true
+SURFSENSE_ENABLE_ACTION_LOG=true
+SURFSENSE_ENABLE_REVERT_ROUTE=true
+SURFSENSE_ENABLE_PERMISSION=true
+SURFSENSE_ENABLE_DOOM_LOOP=true
+SURFSENSE_ENABLE_STREAM_PARITY_V2=true
+
# Periodic connector sync interval (default: 5m)
# SCHEDULE_CHECKER_INTERVAL=5m
@@ -315,9 +336,24 @@ STT_SERVICE=local/base
# Pages limit per user for ETL (default: unlimited)
# PAGES_LIMIT=500
-# Premium token quota per registered user (default: 5M)
-# Only applies to models with billing_tier=premium in global_llm_config.yaml
-# PREMIUM_TOKEN_LIMIT=5000000
+# Premium credit quota per registered user, in micro-USD (default: $5).
+# Premium turns are debited at the actual per-call provider cost reported
+# by LiteLLM. Only applies to models with billing_tier=premium.
+# PREMIUM_CREDIT_MICROS_LIMIT=5000000
+# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000
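+# (1_000_000 micros == $1.00; e.g. a $0.01 turn debits 10_000 micros)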
+
+# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
+# QUOTA_MAX_RESERVE_MICROS=1000000
+
+# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default).
+# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
+
+# Per-podcast reservation for the podcast Celery task ($0.20 default).
+# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
+
+# Per-video-presentation reservation for the video Celery task ($1.00 default).
+# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
+# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
# No-login (anonymous) mode — public users can chat without an account
# Set TRUE to enable /free pages and anonymous chat API
diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example
index 98d396363..ba89059c8 100644
--- a/surfsense_backend/.env.example
+++ b/surfsense_backend/.env.example
@@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000
# Set FALSE to disable new checkout session creation temporarily
STRIPE_PAGE_BUYING_ENABLED=TRUE
-# Premium token purchases via Stripe (for premium-tier model usage)
-# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens)
+# Premium credit purchases via Stripe (for premium-tier model usage).
+# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
+# (default 1_000_000 = $1.00). Premium turns are billed at the actual
+# per-call provider cost reported by LiteLLM.
STRIPE_TOKEN_BUYING_ENABLED=FALSE
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
-STRIPE_TOKENS_PER_UNIT=1000000
+STRIPE_CREDIT_MICROS_PER_UNIT=1000000
+# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
+# STRIPE_TOKENS_PER_UNIT=1000000
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
@@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
PAGES_LIMIT=500
-# Premium token quota per registered user (default: 3,000,000)
-# Applies only to models with billing_tier=premium in global_llm_config.yaml
-PREMIUM_TOKEN_LIMIT=3000000
+# Premium credit quota per registered user, in micro-USD
+# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
+# actual per-call provider cost reported by LiteLLM, so cheap and expensive
+# models bill proportionally. Applies only to models with
+# billing_tier=premium in global_llm_config.yaml.
+PREMIUM_CREDIT_MICROS_LIMIT=5000000
+# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
+# PREMIUM_TOKEN_LIMIT=5000000
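+
+# Worked example: a turn whose LiteLLM-reported cost is $0.0123 debits
+# 12_300 micros, so the $5.00 default covers roughly 406 such turns.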
+
+# Safety ceiling on per-call premium reservation, in micro-USD.
+# stream_new_chat estimates an upper-bound cost from the model's
+# litellm-published per-token rates × the config's quota_reserve_tokens
+# and clamps to this value so a misconfigured model can't lock the
+# user's whole balance on one call. Default $1.00.
+QUOTA_MAX_RESERVE_MICROS=1000000
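+# Illustration (hypothetical rates): a model at $15 per 1M output tokens
+# with quota_reserve_tokens=200_000 estimates 3_000_000 micros ($3.00);
+# this ceiling clamps the reservation to 1_000_000 ($1.00).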
+
+# Per-image reservation (in micro-USD) for the POST /image-generations
+# endpoint. Bypassed for free configs. Default $0.05.
+QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
+
+# Per-podcast reservation (in micro-USD) used by the podcast Celery task.
+# Single envelope covers one transcript-generation LLM call. Default $0.20.
+QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
+
+# Per-video-presentation reservation (in micro-USD) used by the video
+# presentation Celery task. Covers worst-case fan-out of N slide-scene
+# generations + refines. Default $1.00. NOTE: tasks using the override
+# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
+QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
# No-login (anonymous) mode — allows public users to chat without an account
# Set TRUE to enable /free pages and anonymous chat API
@@ -297,3 +327,30 @@ LANGSMITH_PROJECT=surfsense
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
# Comma-separated allowlist of plugin entry-point names
# SURFSENSE_ALLOWED_PLUGINS=year_substituter
+
+# -----------------------------------------------------------------------------
+# Compiled-agent cache (Phase 1 + 2 perf optimization, default ON)
+# -----------------------------------------------------------------------------
+# When ON, the per-turn LangGraph + middleware compile result (~3-5s of CPU
+# on a cold turn) is reused across subsequent turns on the same thread,
+# collapsing it to a microsecond hash lookup. All connector tools acquire
+# their own short-lived DB session per call (Phase 2 refactor) so a cached
+# closure is safe to share across requests. Flip OFF only as a last-resort
+# rollback if you suspect cache-related staleness.
+# SURFSENSE_ENABLE_AGENT_CACHE=true
+
+# Cache capacity (max number of compiled-agent entries kept in memory)
+# and TTL per entry (seconds). Working set is typically one entry per
+# active thread on this replica; tune up for very large deployments.
+# SURFSENSE_AGENT_CACHE_MAXSIZE=256
+# SURFSENSE_AGENT_CACHE_TTL_SECONDS=1800
+
+# -----------------------------------------------------------------------------
+# Connector discovery TTL cache (Phase 1.4 perf optimization)
+# -----------------------------------------------------------------------------
+# Caches the per-search-space "available connectors" + "available document
+# types" lookups that ``create_surfsense_deep_agent`` hits on every turn.
+# ORM event listeners auto-invalidate on connector / document inserts,
+# updates and deletes — the TTL only bounds staleness for bulk-import
+# paths that bypass the ORM. Set to 0 to disable the cache.
+# SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30
diff --git a/surfsense_backend/Dockerfile b/surfsense_backend/Dockerfile
index 1222b36b6..73d5819b9 100644
--- a/surfsense_backend/Dockerfile
+++ b/surfsense_backend/Dockerfile
@@ -38,16 +38,26 @@ RUN pip install --upgrade certifi pip-system-certs
COPY pyproject.toml .
COPY uv.lock .
-# Install PyTorch based on architecture
-RUN if [ "$(uname -m)" = "x86_64" ]; then \
- pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121; \
- else \
- pip install --no-cache-dir torch torchvision torchaudio; \
- fi
-
-# Install python dependencies
+# Install all Python dependencies from uv.lock for deterministic builds.
+#
+# `uv pip install -e .` re-resolves from pyproject.toml and ignores uv.lock,
+# which lets prod silently drift to newer upstream versions on every rebuild
+# (e.g. deepagents 0.4.x -> 0.5.x breaking the FilesystemMiddleware imports).
+# Exporting the lock to requirements.txt and feeding it to `uv pip install`
+# pins every transitive package to the exact version captured in uv.lock.
+#
+# Note on torch/CUDA: we do NOT install torch from a separate cu* index here.
+# PyPI's torch wheels for Linux x86_64 already ship CUDA-enabled and pull
+# nvidia-cudnn-cu13, nvidia-nccl-cu13, triton, etc. as install deps (all
+# captured in uv.lock). Installing from cu121 first only wasted ~2GB of
+# downloads that the lock-based install immediately replaced. If a specific
+# CUDA version is needed (driver compatibility, etc.), wire it through
+# [tool.uv.sources] in pyproject.toml so the lock stays the source of truth.
RUN pip install --no-cache-dir uv && \
- uv pip install --system --no-cache-dir -e .
+ uv export --frozen --no-dev --no-hashes --no-emit-project \
+ --format requirements-txt -o /tmp/requirements.txt && \
+ uv pip install --system --no-cache-dir -r /tmp/requirements.txt && \
+ rm /tmp/requirements.txt
# Set SSL environment variables dynamically
RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \
@@ -66,13 +76,18 @@ RUN cd /root/.EasyOCR/model && (unzip -o english_g2.zip || true) && (unzip -o cr
# Pre-download Docling models
RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true
-# Install Playwright browsers for web scraping if needed
-RUN pip install playwright && \
- playwright install chromium --with-deps
+# Install Playwright browsers for web scraping (the playwright package itself
+# is already installed via uv.lock above)
+RUN playwright install chromium --with-deps
# Copy source code
COPY . .
+# Install the project itself in editable mode. Dependencies were already
+# installed deterministically from uv.lock above, so --no-deps prevents any
+# re-resolution that could pull newer versions.
+RUN uv pip install --system --no-cache-dir --no-deps -e .
+
# Copy and set permissions for entrypoint script
# Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts)
COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh
diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py
new file mode 100644
index 000000000..fba621a0c
--- /dev/null
+++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py
@@ -0,0 +1,44 @@
+"""138_add_thread_auto_model_pinning_fields
+
+Revision ID: 138
+Revises: 137
+Create Date: 2026-04-30
+
+Add a single thread-level column to persist the Auto (Fastest) model pin:
+- pinned_llm_config_id: concrete resolved global LLM config id used for this
+ thread. NULL means "no pin; Auto will resolve on next turn".
+
+The column is unindexed: all reads are by new_chat_threads.id (primary key),
+so a secondary index would add write amplification with no read to serve it.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from alembic import op
+
+revision: str = "138"
+down_revision: str | None = "137"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ op.execute(
+ "ALTER TABLE new_chat_threads "
+ "ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER"
+ )
+
+
+def downgrade() -> None:
+ # Drop any shape the thread row may be carrying. The extra columns and
+ # indexes only exist on dev DBs that ran an earlier draft of 138; IF EXISTS
+ # makes each statement a safe no-op on the lean shape.
+ op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode")
+ op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id")
+ op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at")
+ op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode")
+ op.execute(
+ "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id"
+ )
diff --git a/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py
new file mode 100644
index 000000000..83c96a429
--- /dev/null
+++ b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py
@@ -0,0 +1,160 @@
+"""add user table to zero_publication with column list
+
+Adds the "user" table to zero_publication with a column-list publication
+so that only the 5 fields driving the live usage meters are replicated
+through WAL -> zero-cache -> browser IndexedDB:
+
+ id, pages_limit, pages_used,
+ premium_tokens_limit, premium_tokens_used
+
+Sensitive columns (hashed_password, email, oauth_account, display_name,
+avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT
+included in the publication, so they never enter WAL replication.
+
+Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency
+(it is already DEFAULT today since "user" was never in the
+TABLES_WITH_FULL_IDENTITY list of migration 117).
+
+IMPORTANT - deployment runbook (steps before and after the DDL):
+ 1. Stop zero-cache (it holds replication locks that will deadlock DDL)
+ 2. Run: alembic upgrade head
+ 3. Delete / reset the zero-cache data volume
+ 4. Restart zero-cache (it will do a fresh initial sync)
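+
+A quick way to verify the published column lists after the upgrade
+(psql; the ``attnames`` column requires PostgreSQL 15+):
+
+    SELECT pubname, tablename, attnames
+    FROM pg_publication_tables
+    WHERE pubname = 'zero_publication';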
+
+Revision ID: 139
+Revises: 138
+"""
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+
+from alembic import op
+
+revision: str = "139"
+down_revision: str | None = "138"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+PUBLICATION_NAME = "zero_publication"
+
+# Document column list as left by migration 117. Must match exactly.
+DOCUMENT_COLS = [
+ "id",
+ "title",
+ "document_type",
+ "search_space_id",
+ "folder_id",
+ "created_by_id",
+ "status",
+ "created_at",
+ "updated_at",
+]
+
+# Five fields needed by the live usage meters (sidebar Tokens/Pages,
+# Buy Tokens content). Keep this list narrow on purpose: anything added
+# here flows into WAL and IndexedDB for every connected browser.
+USER_COLS = [
+ "id",
+ "pages_limit",
+ "pages_used",
+ "premium_tokens_limit",
+ "premium_tokens_used",
+]
+
+
+def _terminate_blocked_pids(conn, table: str) -> None:
+ """Kill backends whose locks on *table* would block our AccessExclusiveLock."""
+ conn.execute(
+ sa.text(
+ "SELECT pg_terminate_backend(l.pid) "
+ "FROM pg_locks l "
+ "JOIN pg_class c ON c.oid = l.relation "
+ "WHERE c.relname = :tbl "
+ " AND l.pid != pg_backend_pid()"
+ ),
+ {"tbl": table},
+ )
+
+
+def _has_zero_version(conn, table: str) -> bool:
+ return (
+ conn.execute(
+ sa.text(
+ "SELECT 1 FROM information_schema.columns "
+ "WHERE table_name = :tbl AND column_name = '_0_version'"
+ ),
+ {"tbl": table},
+ ).fetchone()
+ is not None
+ )
+
+
+def _build_publication_ddl(
+ documents_has_zero_ver: bool, user_has_zero_ver: bool
+) -> str:
+ doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
+ user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else [])
+ doc_col_list = ", ".join(doc_cols)
+ user_col_list = ", ".join(user_cols)
+ return (
+ f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
+ f"notifications, "
+ f"documents ({doc_col_list}), "
+ f"folders, "
+ f"search_source_connectors, "
+ f"new_chat_messages, "
+ f"chat_comments, "
+ f"chat_session_state, "
+ f'"user" ({user_col_list})'
+ )
+
+
+def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str:
+ doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
+ doc_col_list = ", ".join(doc_cols)
+ return (
+ f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
+ f"notifications, "
+ f"documents ({doc_col_list}), "
+ f"folders, "
+ f"search_source_connectors, "
+ f"new_chat_messages, "
+ f"chat_comments, "
+ f"chat_session_state"
+ )
+
+
+def upgrade() -> None:
+ conn = op.get_bind()
+ # asyncpg requires LOCK TABLE inside a transaction block. Alembic already
+ # opened one via context.begin_transaction(), but the driver still errors
+ # unless we use an explicit SAVEPOINT (nested transaction) for this block.
+ tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
+ with tx:
+ conn.execute(sa.text("SET lock_timeout = '10s'"))
+
+ _terminate_blocked_pids(conn, "user")
+ conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
+
+ # Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of
+ # migration 117, so this is already DEFAULT. Re-assert anyway so
+ # the column-list publication stays valid (DEFAULT identity only
+ # requires the PK to be in the column list).
+ conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
+
+ conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
+
+ documents_has_zero_ver = _has_zero_version(conn, "documents")
+ user_has_zero_ver = _has_zero_version(conn, "user")
+
+ conn.execute(
+ sa.text(_build_publication_ddl(documents_has_zero_ver, user_has_zero_ver))
+ )
+
+
+def downgrade() -> None:
+ conn = op.get_bind()
+ conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
+ documents_has_zero_ver = _has_zero_version(conn, "documents")
+ conn.execute(sa.text(_build_publication_ddl_without_user(documents_has_zero_ver)))
diff --git a/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py
new file mode 100644
index 000000000..64aa699e8
--- /dev/null
+++ b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py
@@ -0,0 +1,291 @@
+"""rename premium token columns to credit micros and add cost_micros to token_usage
+
+Migrates the premium quota system from a flat token counter to a USD-cost
+based credit system, where 1 credit = 1 micro-USD ($0.000001).
+
+Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe
+price means every existing value is already correct in the new unit, no
+data transformation needed):
+
+ user.premium_tokens_limit -> premium_credit_micros_limit
+ user.premium_tokens_used -> premium_credit_micros_used
+ user.premium_tokens_reserved -> premium_credit_micros_reserved
+
+ premium_token_purchases.tokens_granted -> credit_micros_granted
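+
+Worked example of the 1:1 mapping: a user who had consumed 2_500_000
+premium tokens owed $2.50 at the old $1-per-1M price, and $2.50
+expressed in micro-USD is exactly 2_500_000, so the stored integer is
+already correct in the new unit.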
+
+New column for cost auditing per turn:
+
+ token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0)
+
+The "user" table is in zero_publication's column list (added in 139), so
+this migration must drop and recreate the publication with the renamed
+column names, otherwise zero-cache will replicate stale column names and
+the FE Zero schema will fail to bind.
+
+IMPORTANT - deployment runbook (steps before and after the DDL):
+ 1. Stop zero-cache (it holds replication locks that will deadlock DDL)
+ 2. Run: alembic upgrade head
+ 3. Delete / reset the zero-cache data volume
+ 4. Restart zero-cache (it will do a fresh initial sync)
+
+Skipping the zero-cache stop will deadlock acquiring the ACCESS EXCLUSIVE lock on
+"user". Skipping the data-volume reset will leave IndexedDB clients seeing
+column-not-found errors from a stale catalog snapshot.
+
+Revision ID: 140
+Revises: 139
+"""
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+
+from alembic import op
+
+revision: str = "140"
+down_revision: str | None = "139"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+PUBLICATION_NAME = "zero_publication"
+
+# Replicates 139's document column list verbatim — must stay in sync.
+DOCUMENT_COLS = [
+ "id",
+ "title",
+ "document_type",
+ "search_space_id",
+ "folder_id",
+ "created_by_id",
+ "status",
+ "created_at",
+ "updated_at",
+]
+
+# Same five live-meter fields as 139, with the renamed column names.
+USER_COLS = [
+ "id",
+ "pages_limit",
+ "pages_used",
+ "premium_credit_micros_limit",
+ "premium_credit_micros_used",
+]
+
+
+def _terminate_blocked_pids(conn, table: str) -> None:
+ """Kill backends whose locks on *table* would block our AccessExclusiveLock."""
+ conn.execute(
+ sa.text(
+ "SELECT pg_terminate_backend(l.pid) "
+ "FROM pg_locks l "
+ "JOIN pg_class c ON c.oid = l.relation "
+ "WHERE c.relname = :tbl "
+ " AND l.pid != pg_backend_pid()"
+ ),
+ {"tbl": table},
+ )
+
+
+def _has_zero_version(conn, table: str) -> bool:
+ return (
+ conn.execute(
+ sa.text(
+ "SELECT 1 FROM information_schema.columns "
+ "WHERE table_name = :tbl AND column_name = '_0_version'"
+ ),
+ {"tbl": table},
+ ).fetchone()
+ is not None
+ )
+
+
+def _column_exists(conn, table: str, column: str) -> bool:
+ return (
+ conn.execute(
+ sa.text(
+ "SELECT 1 FROM information_schema.columns "
+ "WHERE table_name = :tbl AND column_name = :col"
+ ),
+ {"tbl": table, "col": column},
+ ).fetchone()
+ is not None
+ )
+
+
+def _build_publication_ddl(
+ user_cols: list[str],
+ *,
+ documents_has_zero_ver: bool,
+ user_has_zero_ver: bool,
+) -> str:
+ doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
+ user_col_list_with_meta = user_cols + (
+ ['"_0_version"'] if user_has_zero_ver else []
+ )
+ doc_col_list = ", ".join(doc_cols)
+ user_col_list = ", ".join(user_col_list_with_meta)
+ return (
+ f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
+ f"notifications, "
+ f"documents ({doc_col_list}), "
+ f"folders, "
+ f"search_source_connectors, "
+ f"new_chat_messages, "
+ f"chat_comments, "
+ f"chat_session_state, "
+ f'"user" ({user_col_list})'
+ )
+
+
+def upgrade() -> None:
+ conn = op.get_bind()
+
+ # ------------------------------------------------------------------
+ # 1. Add cost_micros to token_usage. Idempotent guard so re-runs in
+ # dev environments are safe.
+ # ------------------------------------------------------------------
+ if not _column_exists(conn, "token_usage", "cost_micros"):
+ op.add_column(
+ "token_usage",
+ sa.Column(
+ "cost_micros",
+ sa.BigInteger(),
+ nullable=False,
+ server_default="0",
+ ),
+ )
+
+ # ------------------------------------------------------------------
+ # 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted.
+ # ------------------------------------------------------------------
+ if _column_exists(
+ conn, "premium_token_purchases", "tokens_granted"
+ ) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"):
+ op.alter_column(
+ "premium_token_purchases",
+ "tokens_granted",
+ new_column_name="credit_micros_granted",
+ )
+
+ # ------------------------------------------------------------------
+ # 3. Rename user.premium_tokens_* -> premium_credit_micros_*.
+ #
+ # We must drop the publication first (it references the old column
+ # names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE
+ # in a transaction block; alembic's outer transaction already holds
+ # one, but a SAVEPOINT keeps the LOCK + DDL atomic.
+ # ------------------------------------------------------------------
+ tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
+ with tx:
+ conn.execute(sa.text("SET lock_timeout = '10s'"))
+
+ _terminate_blocked_pids(conn, "user")
+ conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
+
+ # Re-assert REPLICA IDENTITY DEFAULT for safety; column-list
+ # publications require at least the PK to be in the column list,
+ # which is true for both the old and new shape.
+ conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
+
+ # Drop the publication BEFORE renaming columns, otherwise Postgres
+ # rejects the rename: "cannot drop column ... referenced by
+ # publication".
+ conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
+
+ for old, new in (
+ ("premium_tokens_limit", "premium_credit_micros_limit"),
+ ("premium_tokens_used", "premium_credit_micros_used"),
+ ("premium_tokens_reserved", "premium_credit_micros_reserved"),
+ ):
+ if _column_exists(conn, "user", old) and not _column_exists(
+ conn, "user", new
+ ):
+ op.alter_column("user", old, new_column_name=new)
+
+ # Update the server_default on the renamed limit column so newly
+ # inserted users get $5 of credit (== 5_000_000 micros) by
+ # default. Existing rows are unaffected.
+ op.alter_column(
+ "user",
+ "premium_credit_micros_limit",
+ server_default="5000000",
+ )
+
+ # Recreate the publication with the new column names.
+ documents_has_zero_ver = _has_zero_version(conn, "documents")
+ user_has_zero_ver = _has_zero_version(conn, "user")
+ conn.execute(
+ sa.text(
+ _build_publication_ddl(
+ USER_COLS,
+ documents_has_zero_ver=documents_has_zero_ver,
+ user_has_zero_ver=user_has_zero_ver,
+ )
+ )
+ )
+
+
+def downgrade() -> None:
+ """Revert the rename and drop ``cost_micros``.
+
+ Mirrors ``upgrade``: drop the publication, rename columns back, drop
+ the new column, recreate the publication with the old column list.
+ Same zero-cache stop/reset runbook applies in reverse.
+ """
+ conn = op.get_bind()
+
+ tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
+ with tx:
+ conn.execute(sa.text("SET lock_timeout = '10s'"))
+
+ _terminate_blocked_pids(conn, "user")
+ conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
+
+ conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
+
+ for new, old in (
+ ("premium_credit_micros_limit", "premium_tokens_limit"),
+ ("premium_credit_micros_used", "premium_tokens_used"),
+ ("premium_credit_micros_reserved", "premium_tokens_reserved"),
+ ):
+ if _column_exists(conn, "user", new) and not _column_exists(
+ conn, "user", old
+ ):
+ op.alter_column("user", new, new_column_name=old)
+
+ op.alter_column(
+ "user",
+ "premium_tokens_limit",
+ server_default="5000000",
+ )
+
+ legacy_user_cols = [
+ "id",
+ "pages_limit",
+ "pages_used",
+ "premium_tokens_limit",
+ "premium_tokens_used",
+ ]
+ documents_has_zero_ver = _has_zero_version(conn, "documents")
+ user_has_zero_ver = _has_zero_version(conn, "user")
+ conn.execute(
+ sa.text(
+ _build_publication_ddl(
+ legacy_user_cols,
+ documents_has_zero_ver=documents_has_zero_ver,
+ user_has_zero_ver=user_has_zero_ver,
+ )
+ )
+ )
+
+ if _column_exists(
+ conn, "premium_token_purchases", "credit_micros_granted"
+ ) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"):
+ op.alter_column(
+ "premium_token_purchases",
+ "credit_micros_granted",
+ new_column_name="tokens_granted",
+ )
+
+ if _column_exists(conn, "token_usage", "cost_micros"):
+ op.drop_column("token_usage", "cost_micros")
diff --git a/surfsense_backend/alembic/versions/141_unique_chat_message_turn_role.py b/surfsense_backend/alembic/versions/141_unique_chat_message_turn_role.py
new file mode 100644
index 000000000..9a27e7ed0
--- /dev/null
+++ b/surfsense_backend/alembic/versions/141_unique_chat_message_turn_role.py
@@ -0,0 +1,66 @@
+"""141_unique_chat_message_turn_role
+
+Revision ID: 141
+Revises: 140
+Create Date: 2026-05-04
+
+Add a partial unique index on ``new_chat_messages(thread_id, turn_id, role)``
+where ``turn_id IS NOT NULL``.
+
+Why
+---
+The streaming chat path (`stream_new_chat` / `stream_resume_chat`) is being
+moved to write its own ``new_chat_messages`` rows server-side instead of
+relying on the frontend's later ``POST /threads/{id}/messages`` call. This
+closes the "ghost-thread" abuse vector where authenticated callers got free
+LLM completions while ``new_chat_messages`` stayed empty.
+
+For server-side and legacy frontend writes to coexist we need an idempotency
+key. The natural triple is ``(thread_id, turn_id, role)``: the server issues
+exactly one ``turn_id`` per turn, and a turn produces at most one user
+message and one assistant message. Whichever side wins the race writes the
+row; the loser hits ``IntegrityError`` and recovers gracefully.
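+
+Illustrative shape of the write under this index (a sketch; the
+application code may equally use a plain INSERT and catch the
+``IntegrityError``):
+
+    INSERT INTO new_chat_messages (thread_id, turn_id, role /* , ... */)
+    VALUES (:thread_id, :turn_id, 'assistant' /* , ... */)
+    ON CONFLICT (thread_id, turn_id, role) WHERE turn_id IS NOT NULL
+    DO NOTHING;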
+
+Partial — ``WHERE turn_id IS NOT NULL`` — so:
+
+ * Legacy rows that predate the ``turn_id`` column (migration 136) keep
+ co-existing without de-dup.
+ * Clone / snapshot inserts in
+ ``app/services/public_chat_service.py`` that build ``NewChatMessage``
+ without ``turn_id`` are unaffected (multiple snapshot copies of the same
+ user/assistant pair are intentional).
+
+This index coexists with the existing single-column ``ix_new_chat_messages_turn_id``
+from migration 136 — no collision.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+
+from alembic import op
+
+revision: str = "141"
+down_revision: str | None = "140"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+INDEX_NAME = "uq_new_chat_messages_thread_turn_role"
+TABLE_NAME = "new_chat_messages"
+
+
+def upgrade() -> None:
+ op.create_index(
+ INDEX_NAME,
+ TABLE_NAME,
+ ["thread_id", "turn_id", "role"],
+ unique=True,
+ postgresql_where=sa.text("turn_id IS NOT NULL"),
+ )
+
+
+def downgrade() -> None:
+ op.drop_index(INDEX_NAME, table_name=TABLE_NAME)
diff --git a/surfsense_backend/alembic/versions/142_token_usage_message_id_unique.py b/surfsense_backend/alembic/versions/142_token_usage_message_id_unique.py
new file mode 100644
index 000000000..43b30a756
--- /dev/null
+++ b/surfsense_backend/alembic/versions/142_token_usage_message_id_unique.py
@@ -0,0 +1,134 @@
+"""142_token_usage_message_id_unique
+
+Revision ID: 142
+Revises: 141
+Create Date: 2026-05-04
+
+Add a partial unique index on ``token_usage(message_id)`` where
+``message_id IS NOT NULL``.
+
+Why
+---
+Two writers can race on the same assistant turn's ``token_usage`` row:
+
+ * ``finalize_assistant_turn`` (server-side, called from the streaming
+ finally block in ``stream_new_chat`` / ``stream_resume_chat``)
+ * ``append_message``'s recovery branch in
+ ``app/routes/new_chat_routes.py`` (legacy frontend round-trip)
+
+Both currently use ``SELECT ... THEN INSERT`` in separate sessions, so a
+micro-second-aligned race could observe "no row" on each side and double
+INSERT, producing duplicate ``token_usage`` rows for the same
+``message_id``.
+
+A partial unique index on ``message_id`` (``WHERE message_id IS NOT NULL``)
+turns both writes into ``INSERT ... ON CONFLICT (message_id) DO NOTHING``
+no-ops for the loser, hard-eliminating the race at the DB level. Partial
+because non-chat usage rows (indexing, image generation, podcasts) keep
+``message_id`` NULL — they're per-event, no de-dup needed.
+
+Pre-flight
+----------
+Today's schema only has a non-unique index on ``message_id`` so a
+duplicate population could already exist from any past race. We:
+
+ * Detect duplicate ``message_id`` groups (``HAVING COUNT(*) > 1``).
+ * If the group count is at or below ``DUPLICATE_ABORT_THRESHOLD`` (50)
+ we dedupe by deleting all but the smallest ``id`` per group.
+ * If the count exceeds the threshold we abort with a descriptive
+ error rather than silently mutate prod data — operator must
+ investigate before retrying.
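+
+Illustrative operator query for that investigation:
+
+    SELECT message_id, COUNT(*) AS n, MIN(id) AS earliest_id
+    FROM token_usage
+    WHERE message_id IS NOT NULL
+    GROUP BY message_id
+    HAVING COUNT(*) > 1
+    ORDER BY n DESC;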
+
+Concurrency
+-----------
+``CREATE INDEX CONCURRENTLY`` is required on this hot table to avoid
+stalling production writes during deploy (a regular ``CREATE INDEX``
+holds an ACCESS EXCLUSIVE lock for the duration of the build, which
+would block ``token_usage`` INSERTs for every active streaming chat).
+The trade-off is a slower migration (CONCURRENTLY scans the table
+twice) and the ``CREATE`` statement cannot run inside alembic's default
+transaction wrapper — ``autocommit_block()`` handles that.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+
+from alembic import op
+
+revision: str = "142"
+down_revision: str | None = "141"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+INDEX_NAME = "uq_token_usage_message_id"
+TABLE_NAME = "token_usage"
+
+# Refuse to silently mutate prod data if the duplicate population is
+# unexpectedly large — operator should investigate the upstream cause
+# before retrying. 50 is comfortably above any plausible duplicate
+# count from the existing race window (the race is microseconds wide).
+DUPLICATE_ABORT_THRESHOLD = 50
+
+
+def upgrade() -> None:
+ conn = op.get_bind()
+
+ dup_groups = conn.execute(
+ sa.text(
+ "SELECT message_id, COUNT(*) AS n "
+ "FROM token_usage "
+ "WHERE message_id IS NOT NULL "
+ "GROUP BY message_id "
+ "HAVING COUNT(*) > 1"
+ )
+ ).fetchall()
+
+ if len(dup_groups) > DUPLICATE_ABORT_THRESHOLD:
+ raise RuntimeError(
+ f"token_usage has {len(dup_groups)} duplicate message_id groups "
+ f"(threshold={DUPLICATE_ABORT_THRESHOLD}). "
+ "Resolve the duplicates manually before re-running this migration."
+ )
+
+ if dup_groups:
+ # Delete all but the smallest-id row per duplicate group. The
+ # smallest id is by definition the earliest insert, so we keep
+ # the row most likely to reflect the actual stream's first
+ # successful write.
+ conn.execute(
+ sa.text(
+ """
+ DELETE FROM token_usage
+ WHERE id IN (
+ SELECT id FROM (
+ SELECT
+ id,
+ row_number() OVER (
+ PARTITION BY message_id ORDER BY id ASC
+ ) AS rn
+ FROM token_usage
+ WHERE message_id IS NOT NULL
+ ) ranked
+ WHERE rn > 1
+ )
+ """
+ )
+ )
+
+ # CREATE INDEX CONCURRENTLY cannot run inside a transaction. Drop
+ # alembic's auto-transaction for this op only.
+ with op.get_context().autocommit_block():
+ op.execute(
+ f"CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS {INDEX_NAME} "
+ f"ON {TABLE_NAME} (message_id) "
+ "WHERE message_id IS NOT NULL"
+ )
+
+
+def downgrade() -> None:
+ with op.get_context().autocommit_block():
+ op.execute(f"DROP INDEX CONCURRENTLY IF EXISTS {INDEX_NAME}")
diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py
index d3acba175..7afa30a31 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py
@@ -11,7 +11,6 @@ from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
-from .middleware import build_main_agent_deepagent_middleware
from app.agents.multi_agent_chat.subagents.shared.permissions import (
ToolsPermissions,
)
@@ -20,6 +19,8 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.db import ChatVisibility
+from .middleware import build_main_agent_deepagent_middleware
+
def build_compiled_agent_graph_sync(
*,
diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py
index cb387278b..d23dc33a9 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py
@@ -31,8 +31,8 @@ from .propagation import (
from .resume import (
build_resume_command,
fan_out_decisions_to_match,
- hitlrequest_action_count,
get_first_pending_subagent_interrupt,
+ hitlrequest_action_count,
)
logger = logging.getLogger(__name__)
@@ -51,7 +51,9 @@ def build_task_tool_with_parent_config(
)
if task_description is None:
- description = TASK_TOOL_DESCRIPTION.format(available_agents=subagent_description_str)
+ description = TASK_TOOL_DESCRIPTION.format(
+ available_agents=subagent_description_str
+ )
elif "{available_agents}" in task_description:
description = task_description.format(available_agents=subagent_description_str)
else:
@@ -90,11 +92,11 @@ def build_task_tool_with_parent_config(
def task(
description: Annotated[
str,
- "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", # noqa: E501
+ "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",
],
subagent_type: Annotated[
str,
- "The type of subagent to use. Must be one of the available agent types listed in the tool description.", # noqa: E501
+ "The type of subagent to use. Must be one of the available agent types listed in the tool description.",
],
runtime: ToolRuntime,
) -> str | Command:
@@ -119,7 +121,9 @@ def build_task_tool_with_parent_config(
if callable(get_state):
try:
snapshot = get_state(sub_config)
- pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
+ pending_id, pending_value = get_first_pending_subagent_interrupt(
+ snapshot
+ )
except Exception:
# Fail loud if a resume is queued: silent fallback would
# replay the original interrupt to the user.
@@ -158,11 +162,11 @@ def build_task_tool_with_parent_config(
async def atask(
description: Annotated[
str,
- "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", # noqa: E501
+ "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",
],
subagent_type: Annotated[
str,
- "The type of subagent to use. Must be one of the available agent types listed in the tool description.", # noqa: E501
+ "The type of subagent to use. Must be one of the available agent types listed in the tool description.",
],
runtime: ToolRuntime,
) -> str | Command:
@@ -186,7 +190,9 @@ def build_task_tool_with_parent_config(
if callable(aget_state):
try:
snapshot = await aget_state(sub_config)
- pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
+ pending_id, pending_value = get_first_pending_subagent_interrupt(
+ snapshot
+ )
except Exception:
if has_surfsense_resume(runtime):
logger.exception(
diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py
index 1eae3a519..74e47cfab 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py
@@ -23,7 +23,6 @@ from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
-from ...context_prune.prune_tool_names import safe_exclude_tools
from app.agents.multi_agent_chat.subagents import (
build_subagents,
get_subagents_to_exclude,
@@ -66,6 +65,7 @@ from app.agents.new_chat.plugin_loader import (
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
from app.db import ChatVisibility
+from ...context_prune.prune_tool_names import safe_exclude_tools
from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware
diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py
index 6dd3eb721..6a6fd39b7 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py
@@ -14,8 +14,10 @@ from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession
-from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
-from ..tools import MAIN_AGENT_SURFSENSE_TOOL_NAMES, MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED
+from app.agents.multi_agent_chat.subagents import (
+ get_subagents_to_exclude,
+ main_prompt_registry_subagent_lines,
+)
from app.agents.multi_agent_chat.subagents.mcp_tools.index import (
load_mcp_tools_by_connector,
)
@@ -24,17 +26,19 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.agents.new_chat.filesystem_backends import build_backend_resolver
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.llm_config import AgentConfig
-from app.agents.multi_agent_chat.subagents import (
- get_subagents_to_exclude,
- main_prompt_registry_subagent_lines,
-)
-from ..system_prompt import build_main_agent_system_prompt
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
+from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
+from ..system_prompt import build_main_agent_system_prompt
+from ..tools import (
+ MAIN_AGENT_SURFSENSE_TOOL_NAMES,
+ MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
+)
+
_perf_log = get_perf_logger()
diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/__init__.py
index 914257521..80e86e5c8 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/__init__.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/__init__.py
@@ -2,6 +2,9 @@
from __future__ import annotations
-from .index import MAIN_AGENT_SURFSENSE_TOOL_NAMES, MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED
+from .index import (
+ MAIN_AGENT_SURFSENSE_TOOL_NAMES,
+ MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
+)
__all__ = ["MAIN_AGENT_SURFSENSE_TOOL_NAMES", "MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED"]
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py
index c2ebc2029..938e73bd4 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py
@@ -13,7 +13,9 @@ from .resume import create_generate_resume_tool
from .video_presentation import create_generate_video_presentation_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
resolved_dependencies = {**(dependencies or {}), **kwargs}
podcast = create_generate_podcast_tool(
search_space_id=resolved_dependencies["search_space_id"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py
index 4ff02856f..6c65b2cee 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py
@@ -10,7 +10,9 @@ from app.db import ChatVisibility
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
resolved_dependencies = {**(dependencies or {}), **kwargs}
if resolved_dependencies.get("thread_visibility") == ChatVisibility.SEARCH_SPACE:
mem = create_update_team_memory_tool(
@@ -18,7 +20,10 @@ def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) ->
db_session=resolved_dependencies["db_session"],
llm=resolved_dependencies.get("llm"),
)
- return {"allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}], "ask": []}
+ return {
+ "allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}],
+ "ask": [],
+ }
mem = create_update_memory_tool(
user_id=resolved_dependencies["user_id"],
db_session=resolved_dependencies["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py
index 350dab563..3546d4d01 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py
@@ -11,14 +11,20 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool
from .web_search import create_web_search_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
resolved_dependencies = {**(dependencies or {}), **kwargs}
web = create_web_search_tool(
search_space_id=resolved_dependencies.get("search_space_id"),
available_connectors=resolved_dependencies.get("available_connectors"),
)
- scrape = create_scrape_webpage_tool(firecrawl_api_key=resolved_dependencies.get("firecrawl_api_key"))
- docs = create_search_surfsense_docs_tool(db_session=resolved_dependencies["db_session"])
+ scrape = create_scrape_webpage_tool(
+ firecrawl_api_key=resolved_dependencies.get("firecrawl_api_key")
+ )
+ docs = create_search_surfsense_docs_tool(
+ db_session=resolved_dependencies["db_session"]
+ )
return {
"allow": [
{"name": getattr(web, "name", "") or "", "tool": web},
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py
index 9bbfdccb9..08b0e005e 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py
@@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
)
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
_ = {**(dependencies or {}), **kwargs}
return {"allow": [], "ask": []}
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py
index 57d8e277d..2538a494b 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py
@@ -12,7 +12,9 @@ from .search_events import create_search_calendar_events_tool
from .update_event import create_update_calendar_event_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
resolved_dependencies = {**(dependencies or {}), **kwargs}
session_dependencies = {
"db_session": resolved_dependencies["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py
index 9bbfdccb9..08b0e005e 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py
@@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
)
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
_ = {**(dependencies or {}), **kwargs}
return {"allow": [], "ask": []}
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py
index 561ea44ab..28c4ee6ee 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py
@@ -11,7 +11,9 @@ from .delete_page import create_delete_confluence_page_tool
from .update_page import create_update_confluence_page_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
resolved_dependencies = {**(dependencies or {}), **kwargs}
session_dependencies = {
"db_session": resolved_dependencies["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py
index 04db7cda6..c0a3bf3c9 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py
@@ -11,7 +11,9 @@ from .read_messages import create_read_discord_messages_tool
from .send_message import create_send_discord_message_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py
index a25755a8d..5864ae972 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py
@@ -10,7 +10,9 @@ from .create_file import create_create_dropbox_file_tool
from .trash_file import create_delete_dropbox_file_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py
index c355536e8..09082d091 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py
@@ -14,7 +14,9 @@ from .trash_email import create_trash_gmail_email_tool
from .update_draft import create_update_gmail_draft_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py
index 85f414d14..7dbee87a0 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py
@@ -10,7 +10,9 @@ from .create_file import create_create_google_drive_file_tool
from .trash_file import create_delete_google_drive_file_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py
index 9d32c320a..342f120be 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py
@@ -11,7 +11,9 @@ from .delete_issue import create_delete_jira_issue_tool
from .update_issue import create_update_jira_issue_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py
index dc147055e..f1ee49964 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py
@@ -11,7 +11,9 @@ from .delete_issue import create_delete_linear_issue_tool
from .update_issue import create_update_linear_issue_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py
index 053a905a3..47b303295 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py
@@ -11,7 +11,9 @@ from .list_events import create_list_luma_events_tool
from .read_event import create_read_luma_event_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py
index 0323781e5..c78f630a1 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py
@@ -11,7 +11,9 @@ from .delete_page import create_delete_notion_page_tool
from .update_page import create_update_notion_page_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py
index 5f40ba704..9a2dadd36 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py
@@ -10,7 +10,9 @@ from .create_file import create_create_onedrive_file_tool
from .trash_file import create_delete_onedrive_file_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py
index 9bbfdccb9..08b0e005e 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py
@@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
)
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
_ = {**(dependencies or {}), **kwargs}
return {"allow": [], "ask": []}
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py
index 4bc481307..cbe76b040 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py
@@ -11,7 +11,9 @@ from .read_messages import create_read_teams_messages_tool
from .send_message import create_send_teams_message_tool
-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+ *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py
index 68c6ce995..79ab3db10 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py
@@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
## Helper functions for fetching connector metadata maps
+
async def fetch_mcp_connector_metadata_maps(
session: AsyncSession,
search_space_id: int,
@@ -58,6 +59,7 @@ async def fetch_mcp_connector_metadata_maps(
## Helper functions for partitioning tools by connector agent
+
def partition_mcp_tools_by_connector(
tools: Sequence[BaseTool],
connector_id_to_type: dict[int, str],
@@ -104,8 +106,10 @@ def partition_mcp_tools_by_connector(
return dict(buckets)
+
## Helper functions for splitting tools by permissions
+
def _get_mcp_tool_name(tool: BaseTool) -> str:
meta: dict[str, Any] = getattr(tool, "metadata", None) or {}
orig = meta.get("mcp_original_tool_name")
@@ -139,6 +143,7 @@ def _split_tools_by_permissions(
## Main function to load MCP tools and split them by permissions for each connector agent
+
async def load_mcp_tools_by_connector(
session: AsyncSession,
search_space_id: int,
@@ -148,9 +153,7 @@ async def load_mcp_tools_by_connector(
Pass ``bypass_internal_hitl=True`` so the subagent's
``HumanInTheLoopMiddleware`` is the single HITL gate.
"""
- flat = await load_mcp_tools(
- session, search_space_id, bypass_internal_hitl=True
- )
+ flat = await load_mcp_tools(session, search_space_id, bypass_internal_hitl=True)
id_map, name_map = await fetch_mcp_connector_metadata_maps(session, search_space_id)
buckets = partition_mcp_tools_by_connector(flat, id_map, name_map)
return {
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py
index 51906858a..1b7a19ad7 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py
@@ -8,6 +8,9 @@ from typing import Any, Protocol
from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel
+from app.agents.multi_agent_chat.constants import (
+ SUBAGENT_TO_REQUIRED_CONNECTOR_MAP,
+)
from app.agents.multi_agent_chat.subagents.builtins.deliverables.agent import (
build_subagent as build_deliverables_subagent,
)
@@ -62,9 +65,6 @@ from app.agents.multi_agent_chat.subagents.connectors.slack.agent import (
from app.agents.multi_agent_chat.subagents.connectors.teams.agent import (
build_subagent as build_teams_subagent,
)
-from app.agents.multi_agent_chat.constants import (
- SUBAGENT_TO_REQUIRED_CONNECTOR_MAP,
-)
from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
read_md_file,
)
@@ -105,6 +105,7 @@ SUBAGENT_BUILDERS_BY_NAME: dict[str, SubagentBuilder] = {
"teams": build_teams_subagent,
}
+
def _route_resource_package(builder: SubagentBuilder) -> str:
mod = builder.__module__
return mod[: -len(".agent")] if mod.endswith(".agent") else mod.rsplit(".", 1)[0]
diff --git a/surfsense_backend/app/agents/new_chat/agent_cache.py b/surfsense_backend/app/agents/new_chat/agent_cache.py
new file mode 100644
index 000000000..fa8e6fb72
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/agent_cache.py
@@ -0,0 +1,357 @@
+"""TTL-LRU cache for compiled SurfSense deep agents.
+
+Why this exists
+---------------
+
+``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat
+turn:
+
+1. Discover connectors & document types from Postgres (~50-200ms)
+2. Build the tool list (built-in + MCP) (~200ms-1.7s)
+3. Compose the system prompt
+4. Construct ~15 middleware instances (CPU)
+5. Eagerly compile the general-purpose subagent
+ (``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously,
+ which builds a second LangGraph + Pydantic schemas — ~1.5-2s of pure
+ CPU work)
+6. Compile the outer LangGraph
+
+For a single thread, all six steps produce the SAME object on every turn
+unless the user has changed their LLM config, toggled a feature flag,
+added a connector, etc. The right answer is to compile ONCE per
+"agent shape" and reuse the resulting :class:`CompiledStateGraph` for
+every subsequent turn on the same thread.
+
+Why a per-thread key (not a global pool)
+----------------------------------------
+
+Most middleware in the SurfSense stack captures per-thread state in
+``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``,
+``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse
+would silently leak state across users and threads. Keying the cache on
+``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated
+turns on the same thread without changing any middleware's behavior.
+
+Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema`
+(read via ``runtime.context``) so the cache can collapse to a single
+``(llm_config_id, search_space_id, ...)`` key shared across threads. Until
+then, per-thread keying is the only safe option.
+
+Cache shape
+-----------
+
+* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s, 30
+ minutes — matches a typical chat session). ``maxsize`` (default 256)
+ caps memory; LRU evicts least-recently-used on overflow.
+* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent
+ cold misses on the same key wait for the first build instead of
+ building N times.
+* Process-local: this is an in-memory cache. Multi-replica deployments
+ pay the build cost once per replica per key. That's fine; the working
+ set per replica is small (one entry per active thread on that replica).
+
+Telemetry
+---------
+
+Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``:
+
+ * ``hit`` — cache hit, microseconds-fast
+ * ``miss`` — first build for this key, includes build duration
+ * ``stale`` — entry was found but expired; rebuilt
+ * ``evict`` — LRU eviction (size-limited)
+ * ``size`` — current cache occupancy at lookup time
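+
+Usage sketch
+------------
+
+A minimal sketch of the public surface (``config_id`` / ``thread_id`` /
+``build_agent`` are hypothetical stand-ins; the real key is assembled in
+``chat_deepagent.py``)::
+
+    cache = get_cache()
+    key = stable_hash("v1", config_id, thread_id)
+    agent = await cache.get_or_build(key, builder=build_agent)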
+"""
+
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import logging
+import os
+import time
+from collections import OrderedDict
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass
+from typing import Any
+
+from app.utils.perf import get_perf_logger
+
+logger = logging.getLogger(__name__)
+_perf_log = get_perf_logger()
+
+
+# ---------------------------------------------------------------------------
+# Public API: signature helpers (cache key components)
+# ---------------------------------------------------------------------------
+
+
+def stable_hash(*parts: Any) -> str:
+ """Compute a deterministic SHA1 of the str repr of ``parts``.
+
+ Used for cache key components that need a fixed-width representation
+ (system prompt, tool list, etc.). SHA1 is fine here — this is not a
+ security boundary, just a content fingerprint.
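+
+ A determinism sketch (repr-based, so argument type matters)::
+
+     stable_hash("v1", 42) == stable_hash("v1", 42)    # True
+     stable_hash("v1", 42) == stable_hash("v1", "42")  # False: reprs differ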
+ """
+ h = hashlib.sha1(usedforsecurity=False)
+ for p in parts:
+ h.update(repr(p).encode("utf-8", errors="replace"))
+ h.update(b"\x1f") # ASCII unit separator between parts
+ return h.hexdigest()
+
+
+def tools_signature(
+ tools: list[Any] | tuple[Any, ...],
+ *,
+ available_connectors: list[str] | None,
+ available_document_types: list[str] | None,
+) -> str:
+ """Hash the bound-tool surface for cache-key purposes.
+
+ The signature changes whenever:
+
+ * A tool is added or removed from the bound list (a built-in is
+ toggled, the set of MCP tools loaded for the user changes, a gating
+ rule flips, etc.).
+ * The available connectors / document types for the search space
+ change (new connector added, last connector removed, new document
+ type indexed). Because :func:`get_connector_gated_tools` derives
+ ``modified_disabled_tools`` from ``available_connectors``, the
+ tool surface is technically already covered — but we hash the
+ connector list separately so an empty-list "no tools changed"
+ situation still rotates the key when, say, the user re-adds a
+ connector that gates a tool we were already not exposing.
+
+ Stays stable across:
+
+ * Process restarts (tool names + descriptions are static).
+ * Different replicas (everyone gets the same hash for the same
+ inputs).
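+
+ Illustrative call (``SimpleNamespace`` stands in for a real
+ ``BaseTool``; all names here are made up)::
+
+     from types import SimpleNamespace
+
+     sig = tools_signature(
+         [SimpleNamespace(name="search_kb", description="KB search")],
+         available_connectors=["SLACK_CONNECTOR"],
+         available_document_types=None,
+     )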
+ """
+ tool_descriptors = sorted(
+ (getattr(t, "name", repr(t)), getattr(t, "description", "")) for t in tools
+ )
+ connectors = sorted(available_connectors or [])
+ doc_types = sorted(available_document_types or [])
+ return stable_hash(tool_descriptors, connectors, doc_types)
+
+
+def flags_signature(flags: Any) -> str:
+ """Hash the resolved :class:`AgentFeatureFlags` dataclass.
+
+ Frozen dataclasses are deterministically reprable, so a SHA1 of their
+ repr is a stable fingerprint. Restart safe (flags are read once at
+ process boot).
+ """
+ return stable_hash(repr(flags))
+
+
+def system_prompt_hash(system_prompt: str) -> str:
+ """Hash a system prompt string. Cheap, ~30µs for typical prompts."""
+ return hashlib.sha1(
+ system_prompt.encode("utf-8", errors="replace"),
+ usedforsecurity=False,
+ ).hexdigest()
+
+
+# ---------------------------------------------------------------------------
+# Cache implementation
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _Entry:
+ value: Any
+ created_at: float
+ last_used_at: float
+
+
+class _AgentCache:
+ """In-process TTL-LRU cache with per-key in-flight de-duplication.
+
+ NOT THREAD-SAFE in the multithreading sense — designed for a single
+ asyncio event loop. Uvicorn runs one event loop per worker process,
+ so this is fine; in a multi-worker deployment each worker simply
+ maintains its own cache.
+ """
+
+ def __init__(self, *, maxsize: int, ttl_seconds: float) -> None:
+ self._maxsize = maxsize
+ self._ttl = ttl_seconds
+ self._entries: OrderedDict[str, _Entry] = OrderedDict()
+ # One lock per key — guards "build" so concurrent cold misses on
+ # the same key wait for the first build instead of all racing.
+ self._locks: dict[str, asyncio.Lock] = {}
+
+ def _now(self) -> float:
+ return time.monotonic()
+
+ def _is_fresh(self, entry: _Entry) -> bool:
+ return (self._now() - entry.created_at) < self._ttl
+
+ def _evict_if_full(self) -> None:
+ while len(self._entries) >= self._maxsize:
+ evicted_key, _ = self._entries.popitem(last=False)
+ self._locks.pop(evicted_key, None)
+ _perf_log.info(
+ "[agent_cache] evict key=%s reason=lru size=%d",
+ _short(evicted_key),
+ len(self._entries),
+ )
+
+ def _touch(self, key: str, entry: _Entry) -> None:
+ entry.last_used_at = self._now()
+ self._entries.move_to_end(key, last=True)
+
+ async def get_or_build(
+ self,
+ key: str,
+ *,
+ builder: Callable[[], Awaitable[Any]],
+ ) -> Any:
+ """Return the cached value for ``key`` or call ``builder()`` to make it.
+
+ ``builder`` MUST be idempotent — concurrent cold misses on the
+ same key collapse to a single ``builder()`` call (the others
+ wait on the in-flight lock and observe the populated entry on
+ wake).
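+
+        Coalescing sketch (names illustrative)::
+
+            a, b = await asyncio.gather(
+                cache.get_or_build("k", builder=slow_build),
+                cache.get_or_build("k", builder=slow_build),
+            )
+            assert a is b  # slow_build ran once; both callers share it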
+ """
+ # Fast path: hot hit.
+ entry = self._entries.get(key)
+ if entry is not None and self._is_fresh(entry):
+ self._touch(key, entry)
+ _perf_log.info(
+ "[agent_cache] hit key=%s age=%.1fs size=%d",
+ _short(key),
+ self._now() - entry.created_at,
+ len(self._entries),
+ )
+ return entry.value
+
+ # Stale entry — drop it; rebuild below.
+ if entry is not None and not self._is_fresh(entry):
+ _perf_log.info(
+ "[agent_cache] stale key=%s age=%.1fs ttl=%.0fs",
+ _short(key),
+ self._now() - entry.created_at,
+ self._ttl,
+ )
+ self._entries.pop(key, None)
+
+ # Slow path: serialize concurrent misses for the same key.
+ lock = self._locks.setdefault(key, asyncio.Lock())
+ async with lock:
+ # Double-check after acquiring the lock — another waiter may
+ # have populated the entry while we slept.
+ entry = self._entries.get(key)
+ if entry is not None and self._is_fresh(entry):
+ self._touch(key, entry)
+ _perf_log.info(
+ "[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true",
+ _short(key),
+ self._now() - entry.created_at,
+ len(self._entries),
+ )
+ return entry.value
+
+ t0 = time.perf_counter()
+ try:
+ value = await builder()
+ except BaseException:
+ # Don't cache failed builds; let the next caller retry.
+ _perf_log.warning(
+ "[agent_cache] build_failed key=%s elapsed=%.3fs",
+ _short(key),
+ time.perf_counter() - t0,
+ )
+ raise
+ elapsed = time.perf_counter() - t0
+
+ # Insert + evict.
+ self._evict_if_full()
+ now = self._now()
+ self._entries[key] = _Entry(value=value, created_at=now, last_used_at=now)
+ self._entries.move_to_end(key, last=True)
+ _perf_log.info(
+ "[agent_cache] miss key=%s build=%.3fs size=%d",
+ _short(key),
+ elapsed,
+ len(self._entries),
+ )
+ return value
+
+ def invalidate(self, key: str) -> bool:
+ """Drop a single entry; return True if anything was removed."""
+ removed = self._entries.pop(key, None) is not None
+ self._locks.pop(key, None)
+ if removed:
+ _perf_log.info(
+ "[agent_cache] invalidate key=%s size=%d",
+ _short(key),
+ len(self._entries),
+ )
+ return removed
+
+ def invalidate_prefix(self, prefix: str) -> int:
+ """Drop every entry whose key starts with ``prefix``. Returns count."""
+ keys = [k for k in self._entries if k.startswith(prefix)]
+ for k in keys:
+ self._entries.pop(k, None)
+ self._locks.pop(k, None)
+ if keys:
+ _perf_log.info(
+ "[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d",
+ _short(prefix),
+ len(keys),
+ len(self._entries),
+ )
+ return len(keys)
+
+ def clear(self) -> None:
+ n = len(self._entries)
+ self._entries.clear()
+ self._locks.clear()
+ if n:
+ _perf_log.info("[agent_cache] clear removed=%d", n)
+
+ def stats(self) -> dict[str, Any]:
+ return {
+ "size": len(self._entries),
+ "maxsize": self._maxsize,
+ "ttl_seconds": self._ttl,
+ }
+
+
+def _short(key: str, n: int = 16) -> str:
+ """Truncate keys for log lines so they don't blow up log volume."""
+ return key if len(key) <= n else f"{key[:n]}..."
+
+
+# ---------------------------------------------------------------------------
+# Module-level singleton
+# ---------------------------------------------------------------------------
+
+_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256"))
+_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800"))
+
+_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL)
+
+
+def get_cache() -> _AgentCache:
+ """Return the process-wide compiled-agent cache singleton."""
+ return _cache
+
+
+def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache:
+ """Replace the singleton with a fresh cache. Tests only."""
+ global _cache
+ _cache = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds)
+ return _cache
+
+
+__all__ = [
+ "flags_signature",
+ "get_cache",
+ "reload_for_tests",
+ "stable_hash",
+ "system_prompt_hash",
+ "tools_signature",
+]
diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py
index b3fe7850b..1f4024d9d 100644
--- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py
+++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py
@@ -10,7 +10,9 @@ We use ``create_agent`` (from langchain) rather than ``create_deep_agent``
This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable
subclass of the default ``FilesystemMiddleware`` — while preserving every
other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
-summarisation, prompt-caching, etc.).
+summarisation, etc.). Prompt caching is configured at LLM-build time via
+``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather
+than as a middleware.
"""
import asyncio
@@ -33,12 +35,18 @@ from langchain.agents.middleware import (
TodoListMiddleware,
ToolCallLimitMiddleware,
)
-from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession
+from app.agents.new_chat.agent_cache import (
+ flags_signature,
+ get_cache,
+ stable_hash,
+ system_prompt_hash,
+ tools_signature,
+)
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.agents.new_chat.filesystem_backends import build_backend_resolver
@@ -52,6 +60,7 @@ from app.agents.new_chat.middleware import (
DedupHITLToolCallsMiddleware,
DoomLoopMiddleware,
FileIntentMiddleware,
+ FlattenSystemMessageMiddleware,
KnowledgeBasePersistenceMiddleware,
KnowledgePriorityMiddleware,
KnowledgeTreeMiddleware,
@@ -74,6 +83,7 @@ from app.agents.new_chat.plugin_loader import (
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
+from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
from app.agents.new_chat.subagents import build_specialized_subagents
from app.agents.new_chat.system_prompt import (
build_configurable_system_prompt,
@@ -94,6 +104,39 @@ from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
+
+def _resolve_prompt_model_name(
+ agent_config: AgentConfig | None,
+ llm: BaseChatModel,
+) -> str | None:
+ """Resolve the model id to feed to provider-variant detection.
+
+ Preference order (matches the established idiom in
+ ``llm_router_service.py`` — see ``params.get("base_model") or
+ params.get("model", "")`` usages there):
+
+ 1. ``agent_config.litellm_params["base_model"]`` — required for Azure
+ deployments where ``model_name`` is the deployment slug, not the
+ underlying family. Without this, a deployment named e.g.
+ ``"prod-chat-001"`` would silently miss every provider regex.
+ 2. ``agent_config.model_name`` — the user's configured model id.
+ 3. ``getattr(llm, "model", None)`` — fallback for direct callers that
+ don't supply an ``AgentConfig`` (currently a defensive path; all
+ production callers pass ``agent_config``).
+
+ Returns ``None`` when nothing is available; ``compose_system_prompt``
+ treats that as the ``"default"`` variant (no provider block emitted).
+ """
+ if agent_config is not None:
+ params = agent_config.litellm_params or {}
+ base_model = params.get("base_model")
+ if isinstance(base_model, str) and base_model.strip():
+ return base_model
+ if agent_config.model_name:
+ return agent_config.model_name
+ return getattr(llm, "model", None)
+
+
# =============================================================================
# Connector Type Mapping
# =============================================================================
@@ -279,6 +322,14 @@ async def create_surfsense_deep_agent(
)
"""
_t_agent_total = time.perf_counter()
+
+ # Layer thread-aware prompt caching onto the LLM. Idempotent with the
+ # build-time call in ``llm_config.py``; this run merely adds
+ # ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family
+ # configs now that ``thread_id`` is known. No-op when ``thread_id`` is
+ # None or the provider is non-OpenAI-family.
+ apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
+
filesystem_selection = filesystem_selection or FilesystemSelection()
backend_resolver = build_backend_resolver(
filesystem_selection,
@@ -287,23 +338,39 @@ async def create_surfsense_deep_agent(
else None,
)
- # Discover available connectors and document types for this search space
+ # Discover available connectors and document types for this search space.
+ #
+ # NOTE: These two calls cannot be parallelized via ``asyncio.gather``.
+ # ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``);
+ # SQLAlchemy explicitly forbids concurrent operations on the same session
+ # ("This session is provisioning a new connection; concurrent operations
+ # are not permitted on the same session"). The Phase 1.4 in-process TTL
+ # cache in ``connector_service`` already collapses the warm path to a
+ # near-zero pair of dict lookups, so sequential awaits cost nothing in
+ # the common case while remaining correct on cold cache misses.
available_connectors: list[str] | None = None
available_document_types: list[str] | None = None
_t0 = time.perf_counter()
try:
- connector_types = await connector_service.get_available_connectors(
- search_space_id
- )
- if connector_types:
- available_connectors = _map_connectors_to_searchable_types(connector_types)
+ try:
+ connector_types_result = await connector_service.get_available_connectors(
+ search_space_id
+ )
+ if connector_types_result:
+ available_connectors = _map_connectors_to_searchable_types(
+ connector_types_result
+ )
+ except Exception as e:
+ logging.warning("Failed to discover available connectors: %s", e)
- available_document_types = await connector_service.get_available_document_types(
- search_space_id
- )
-
- except Exception as e:
+ try:
+ available_document_types = (
+ await connector_service.get_available_document_types(search_space_id)
+ )
+ except Exception as e:
+ logging.warning("Failed to discover available document types: %s", e)
+ except Exception as e: # pragma: no cover - defensive outer guard
logging.warning(f"Failed to discover available connectors/document types: {e}")
_perf_log.info(
"[create_agent] Connector/doc-type discovery in %.3fs",
@@ -398,6 +465,7 @@ async def create_surfsense_deep_agent(
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
+ model_name=_resolve_prompt_model_name(agent_config, llm),
)
else:
system_prompt = build_surfsense_system_prompt(
@@ -405,6 +473,7 @@ async def create_surfsense_deep_agent(
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
+ model_name=_resolve_prompt_model_name(agent_config, llm),
)
_perf_log.info(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
@@ -424,29 +493,77 @@ async def create_surfsense_deep_agent(
# entire middleware build + main-graph compile into a single
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
# event loop stays responsive.
+ #
+ # PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed
+ # on every per-request value that any middleware in the stack closes
+ # over in ``__init__`` — drop one and you risk leaking state across
+ # threads. Hits collapse this whole block to a microsecond lookup;
+ # misses pay the original CPU cost AND populate the cache.
+ config_id = agent_config.config_id if agent_config is not None else None
+
+ async def _build_agent() -> Any:
+ return await asyncio.to_thread(
+ _build_compiled_agent_blocking,
+ llm=llm,
+ tools=tools,
+ final_system_prompt=final_system_prompt,
+ backend_resolver=backend_resolver,
+ filesystem_mode=filesystem_selection.mode,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ thread_id=thread_id,
+ visibility=visibility,
+ anon_session_id=anon_session_id,
+ available_connectors=available_connectors,
+ available_document_types=available_document_types,
+ # ``mentioned_document_ids`` is consumed by
+ # ``KnowledgePriorityMiddleware`` per turn via
+ # ``runtime.context`` (Phase 1.5). We still pass the
+ # caller-provided list here for the legacy fallback path
+ # (cache disabled / context not propagated) — the middleware
+ # drains its own copy after the first read so a cached graph
+ # never replays stale mentions.
+ mentioned_document_ids=mentioned_document_ids,
+ max_input_tokens=_max_input_tokens,
+ flags=_flags,
+ checkpointer=checkpointer,
+ )
+
_t0 = time.perf_counter()
- agent = await asyncio.to_thread(
- _build_compiled_agent_blocking,
- llm=llm,
- tools=tools,
- final_system_prompt=final_system_prompt,
- backend_resolver=backend_resolver,
- filesystem_mode=filesystem_selection.mode,
- search_space_id=search_space_id,
- user_id=user_id,
- thread_id=thread_id,
- visibility=visibility,
- anon_session_id=anon_session_id,
- available_connectors=available_connectors,
- available_document_types=available_document_types,
- mentioned_document_ids=mentioned_document_ids,
- max_input_tokens=_max_input_tokens,
- flags=_flags,
- checkpointer=checkpointer,
- )
+ if _flags.enable_agent_cache and not _flags.disable_new_agent_stack:
+ # Cache key components — order matters only for human readability;
+ # the resulting hash is what's stored. Every component must
+ # rotate on a real shape change AND stay stable across identical
+ # invocations.
+ cache_key = stable_hash(
+ "v1", # schema version of the key — bump if components change
+ config_id,
+ thread_id,
+ user_id,
+ search_space_id,
+ visibility,
+ filesystem_selection.mode,
+ anon_session_id,
+ tools_signature(
+ tools,
+ available_connectors=available_connectors,
+ available_document_types=available_document_types,
+ ),
+ flags_signature(_flags),
+ system_prompt_hash(final_system_prompt),
+ _max_input_tokens,
+ # ``mentioned_document_ids`` deliberately omitted — middleware
+ # reads it from ``runtime.context`` (Phase 1.5).
+ )
+ agent = await get_cache().get_or_build(cache_key, builder=_build_agent)
+ else:
+ agent = await _build_agent()
_perf_log.info(
- "[create_agent] Middleware stack + graph compiled in %.3fs",
+ "[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)",
time.perf_counter() - _t0,
+ "on"
+ if _flags.enable_agent_cache and not _flags.disable_new_agent_stack
+ else "off",
)
_perf_log.info(
@@ -568,7 +685,6 @@ def _build_compiled_agent_blocking(
),
create_surfsense_compaction_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
- AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
@@ -998,6 +1114,14 @@ def _build_compiled_agent_blocking(
noop_mw,
retry_mw,
fallback_mw,
+ # Coalesce a multi-text-block system message into one block
+ # immediately before the model call. Sits innermost on the
+ # system-message-mutation chain so it observes every appender
+ # (todo / filesystem / skills / subagents …) and prevents
+ # OpenRouter→Anthropic from redistributing ``cache_control``
+ # across N blocks and tripping Anthropic's 4-breakpoint cap.
+ # See ``middleware/flatten_system.py`` for full rationale.
+ FlattenSystemMessageMiddleware(),
# Tool-call repair must run after model emits but before
# permission / dedup / doom-loop interpret the calls.
repair_mw,
@@ -1010,12 +1134,12 @@ def _build_compiled_agent_blocking(
action_log_mw,
PatchToolCallsMiddleware(),
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
- # Plugin slot — sits just before AnthropicCache so plugin-side
- # transforms see the final tool result and run before any
- # caching heuristics. Multiple plugins in declared order; loader
- # filtered by the admin allowlist already.
+ # Plugin slot — sits at the tail so plugin-side transforms see the
+ # final tool result. Prompt caching is now applied at LLM build time
+ # via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no
+ # caching middleware is needed here. Multiple plugins run in declared
+ # order; loader filtered by the admin allowlist already.
*plugin_middlewares,
- AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
deepagent_middleware = [m for m in deepagent_middleware if m is not None]
diff --git a/surfsense_backend/app/agents/new_chat/context.py b/surfsense_backend/app/agents/new_chat/context.py
index c1fe45aaa..d720b524b 100644
--- a/surfsense_backend/app/agents/new_chat/context.py
+++ b/surfsense_backend/app/agents/new_chat/context.py
@@ -1,10 +1,25 @@
"""
Context schema definitions for SurfSense agents.
-This module defines the custom state schema used by the SurfSense deep agent.
+This module defines the per-invocation context object passed to the SurfSense
+deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6).
+
+The agent's compiled graph is the same across invocations (and cached by
+``agent_cache``), so anything that varies per turn — the user mentions a
+specific document, the front-end issues a unique ``request_id``, etc. —
+MUST live on this context object instead of being captured into a
+middleware ``__init__`` closure. Middleware and tools alike read those
+fields back via ``runtime.context``.
+
+This object is read inside both ``KnowledgePriorityMiddleware`` (for
+``mentioned_document_ids``) and any future middleware that needs
+per-request state without invalidating the compiled-agent cache.
"""
-from typing import NotRequired, TypedDict
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TypedDict
class FileOperationContractState(TypedDict):
@@ -15,25 +30,35 @@ class FileOperationContractState(TypedDict):
turn_id: str
-class SurfSenseContextSchema(TypedDict):
+@dataclass
+class SurfSenseContextSchema:
"""
- Custom state schema for the SurfSense deep agent.
+ Per-invocation context for the SurfSense deep agent.
- This extends the default agent state with custom fields.
- The default state already includes:
- - messages: Conversation history
- - todos: Task list from TodoListMiddleware
- - files: Virtual filesystem from FilesystemMiddleware
+ Defaults are chosen so the dataclass can be safely default-constructed
+ (LangGraph's ``Runtime.context`` itself defaults to ``None`` if no
+ context is supplied — see ``langgraph.runtime.Runtime``). All fields
+ are optional; consumers must None-check before reading.
- We're adding fields needed for knowledge base search:
- - search_space_id: The user's search space ID
- - db_session: Database session (injected at runtime)
- - connector_service: Connector service instance (injected at runtime)
+ Phase 1.5 fields:
+ search_space_id: Search space the request is scoped to.
+ mentioned_document_ids: KB documents the user @-mentioned this turn.
+ Read by ``KnowledgePriorityMiddleware`` to seed its priority
+ list. Stays out of the compiled-agent cache key — that's the
+ whole point of putting it here.
+ file_operation_contract: One-shot file operation contract emitted
+ by ``FileIntentMiddleware`` for the upcoming turn.
+ turn_id / request_id: Correlation IDs surfaced by the streaming
+ task; populated for telemetry.
+
+ Phase 2 will extend with: thread_id, user_id, visibility,
+ filesystem_mode, anon_session_id, available_connectors,
+ available_document_types, created_by_id (everything currently captured
+ by middleware ``__init__`` closures).
"""
- search_space_id: int
- file_operation_contract: NotRequired[FileOperationContractState]
- turn_id: NotRequired[str]
- request_id: NotRequired[str]
- # These are runtime-injected and won't be serialized
- # db_session and connector_service are passed when invoking the agent
+ search_space_id: int | None = None
+ mentioned_document_ids: list[int] = field(default_factory=list)
+ file_operation_contract: FileOperationContractState | None = None
+ turn_id: str | None = None
+ request_id: str | None = None
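+
+    # Construction sketch (values illustrative). The streaming task passes
+    # this per turn, e.g. ``agent.astream_events(inputs, context=ctx)``:
+    #
+    #     ctx = SurfSenseContextSchema(
+    #         search_space_id=1,
+    #         mentioned_document_ids=[42],
+    #     )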
diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py
index c6524069a..b3dc0fa82 100644
--- a/surfsense_backend/app/agents/new_chat/feature_flags.py
+++ b/surfsense_backend/app/agents/new_chat/feature_flags.py
@@ -3,8 +3,10 @@ Feature flags for the SurfSense new_chat agent stack.
These flags gate the newer agent middleware (some ported from OpenCode,
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
-SurfSense-native). They follow a "default-OFF for risky things,
-default-ON for safe upgrades, master kill-switch for everything new" model.
+SurfSense-native). Most shipped agent-stack upgrades now default ON, so a
+Docker image update picks up the new behavior even on older installs that
+lack the newly introduced environment variables. Risky or experimental
+integrations stay default OFF, and the master kill-switch can still
+disable everything new.
All new middleware checks its flag at agent build time. If the master
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
@@ -14,16 +16,19 @@ operators a single switch to revert to pre-port behavior.
Examples
--------
-Local development (recommended for trying everything except doom-loop / selector):
+Defaults:
SURFSENSE_ENABLE_CONTEXT_EDITING=true
SURFSENSE_ENABLE_COMPACTION_V2=true
SURFSENSE_ENABLE_RETRY_AFTER=true
+ SURFSENSE_ENABLE_MODEL_FALLBACK=false
+ SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
+ SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
- SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
- SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
- SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
- SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events
+ SURFSENSE_ENABLE_PERMISSION=true
+ SURFSENSE_ENABLE_DOOM_LOOP=true
+ SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
+ SURFSENSE_ENABLE_STREAM_PARITY_V2=true
Master kill-switch (overrides everything else):
@@ -60,32 +65,28 @@ class AgentFeatureFlags:
disable_new_agent_stack: bool = False
# Agent quality — context budget, retry/limits, name-repair, doom-loop
- enable_context_editing: bool = False
- enable_compaction_v2: bool = False
- enable_retry_after: bool = False
+ enable_context_editing: bool = True
+ enable_compaction_v2: bool = True
+ enable_retry_after: bool = True
enable_model_fallback: bool = False
- enable_model_call_limit: bool = False
- enable_tool_call_limit: bool = False
- enable_tool_call_repair: bool = False
- enable_doom_loop: bool = (
- False # Default OFF until UI handles permission='doom_loop'
- )
+ enable_model_call_limit: bool = True
+ enable_tool_call_limit: bool = True
+ enable_tool_call_repair: bool = True
+ enable_doom_loop: bool = True
# Safety — permissions, concurrency, tool-set narrowing
- enable_permission: bool = False # Default OFF for first deploy
- enable_busy_mutex: bool = False
+ enable_permission: bool = True
+ enable_busy_mutex: bool = True
enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost
# Skills + subagents
- enable_skills: bool = False
- enable_specialized_subagents: bool = False
- enable_kb_planner_runnable: bool = False
+ enable_skills: bool = True
+ enable_specialized_subagents: bool = True
+ enable_kb_planner_runnable: bool = True
# Snapshot / revert
- enable_action_log: bool = False
- enable_revert_route: bool = (
- False # Backend ships before UI; route returns 503 until this flips
- )
+ enable_action_log: bool = True
+ enable_revert_route: bool = True
# Streaming parity v2 — opt in to LangChain's structured
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
@@ -94,7 +95,7 @@ class AgentFeatureFlags:
# text path and the synthetic ``call_`` tool-call id (no
# ``langchainToolCallId`` propagation). Schema migrations 135/136
# ship unconditionally because they're forward-compatible.
- enable_stream_parity_v2: bool = False
+ enable_stream_parity_v2: bool = True
# Plugins
enable_plugin_loader: bool = False
@@ -102,6 +103,41 @@ class AgentFeatureFlags:
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
enable_otel: bool = False
+ # Performance — compiled-agent cache (Phase 1 + Phase 2).
+ # When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled
+ # graph if the cache key matches (LLM config + thread + tool surface +
+ # flags + system prompt + filesystem mode). Cuts per-turn agent-build
+ # wall clock from ~4-5s to <50µs on cache hits.
+ #
+ # SAFETY (Phase 2 unblocked this default-on):
+ # All connector mutation tools (``tools/notion``, ``tools/gmail``,
+ # ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
+ # ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
+ # ``tools/teams``, ``tools/luma``, ``connected_accounts``,
+ # ``update_memory``, ``search_surfsense_docs``) now acquire fresh
+ # short-lived ``AsyncSession`` instances per call via
+ # :data:`async_session_maker`. The factory still accepts ``db_session``
+ # for registry compatibility but ``del``'s it immediately — see any
+ # of those files' factory docstrings for the rationale. The ``llm``
+ # closure is per-(provider, model, config_id) which is already in
+ # the cache key, so the LLM is safe to share across cached hits of
+ # the same key. The KB priority middleware reads
+ # ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
+ # not its constructor closure, so the same compiled agent serves
+ # turns with different mention lists correctly.
+ #
+ # Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
+ # environment if a regression surfaces. The path is exercised by
+ # the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite.
+ enable_agent_cache: bool = True
+ # Phase 1 (deferred — measure first): pre-build & share the
+ # general-purpose subagent ``CompiledSubAgent`` across cold-cache
+ # misses. Only helps when the outer cache MISSES (cache hits already
+ # reuse the entire SubAgentMiddleware-compiled graph). Off by default
+ # until we have data showing cold misses are frequent enough to
+ # justify the extra global state.
+ enable_agent_cache_share_gp_subagent: bool = False
+
@classmethod
def from_env(cls) -> AgentFeatureFlags:
"""Read flags from environment.
@@ -115,48 +151,76 @@ class AgentFeatureFlags:
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
"middleware is forced OFF for this build."
)
- return cls(disable_new_agent_stack=True)
+ return cls(
+ disable_new_agent_stack=True,
+ enable_context_editing=False,
+ enable_compaction_v2=False,
+ enable_retry_after=False,
+ enable_model_fallback=False,
+ enable_model_call_limit=False,
+ enable_tool_call_limit=False,
+ enable_tool_call_repair=False,
+ enable_doom_loop=False,
+ enable_permission=False,
+ enable_busy_mutex=False,
+ enable_llm_tool_selector=False,
+ enable_skills=False,
+ enable_specialized_subagents=False,
+ enable_kb_planner_runnable=False,
+ enable_action_log=False,
+ enable_revert_route=False,
+ enable_stream_parity_v2=False,
+ enable_plugin_loader=False,
+ enable_otel=False,
+ enable_agent_cache=False,
+ enable_agent_cache_share_gp_subagent=False,
+ )
return cls(
disable_new_agent_stack=False,
# Agent quality
- enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False),
- enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False),
- enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False),
+ enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True),
+ enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True),
+ enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True),
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
enable_model_call_limit=_env_bool(
- "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False
+ "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True
),
- enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False),
+ enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True),
enable_tool_call_repair=_env_bool(
- "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False
+ "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True
),
- enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False),
+ enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True),
# Safety
- enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False),
- enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False),
+ enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True),
+ enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True),
enable_llm_tool_selector=_env_bool(
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
),
# Skills + subagents
- enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False),
+ enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True),
enable_specialized_subagents=_env_bool(
- "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False
+ "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True
),
enable_kb_planner_runnable=_env_bool(
- "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False
+ "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True
),
# Snapshot / revert
- enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
- enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
+ enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
+ enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
# Streaming parity v2
enable_stream_parity_v2=_env_bool(
- "SURFSENSE_ENABLE_STREAM_PARITY_V2", False
+ "SURFSENSE_ENABLE_STREAM_PARITY_V2", True
),
# Plugins
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
# Observability
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
+ # Performance
+ enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True),
+ enable_agent_cache_share_gp_subagent=_env_bool(
+ "SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False
+ ),
)
def any_new_middleware_enabled(self) -> bool:
diff --git a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/new_chat/llm_config.py
index 58d8f84d0..bc37bf1c4 100644
--- a/surfsense_backend/app/agents/new_chat/llm_config.py
+++ b/surfsense_backend/app/agents/new_chat/llm_config.py
@@ -27,6 +27,7 @@ from litellm import get_model_info
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
+from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
@@ -89,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
yield chunk
-# Provider mapping for LiteLLM model string construction
-PROVIDER_MAP = {
- "OPENAI": "openai",
- "ANTHROPIC": "anthropic",
- "GROQ": "groq",
- "COHERE": "cohere",
- "GOOGLE": "gemini",
- "OLLAMA": "ollama_chat",
- "MISTRAL": "mistral",
- "AZURE_OPENAI": "azure",
- "OPENROUTER": "openrouter",
- "XAI": "xai",
- "BEDROCK": "bedrock",
- "VERTEX_AI": "vertex_ai",
- "TOGETHER_AI": "together_ai",
- "FIREWORKS_AI": "fireworks_ai",
- "DEEPSEEK": "openai",
- "ALIBABA_QWEN": "openai",
- "MOONSHOT": "openai",
- "ZHIPU": "openai",
- "GITHUB_MODELS": "github",
- "REPLICATE": "replicate",
- "PERPLEXITY": "perplexity",
- "ANYSCALE": "anyscale",
- "DEEPINFRA": "deepinfra",
- "CEREBRAS": "cerebras",
- "SAMBANOVA": "sambanova",
- "AI21": "ai21",
- "CLOUDFLARE": "cloudflare",
- "DATABRICKS": "databricks",
- "COMETAPI": "cometapi",
- "HUGGINGFACE": "huggingface",
- "MINIMAX": "openai",
- "CUSTOM": "custom",
-}
+# Provider mapping for LiteLLM model string construction.
+#
+# Single source of truth lives in
+# :mod:`app.services.provider_capabilities` so the YAML loader (which
+# runs during ``app.config`` class-body init) can resolve provider
+# prefixes without dragging the agent / tools tree into module load
+# order. Re-exported here under the historical ``PROVIDER_MAP`` name
+# so existing callers (``llm_router_service``, ``image_gen_router_service``,
+# tests) keep working unchanged.
+from app.services.provider_capabilities import ( # noqa: E402
+ _PROVIDER_PREFIX_MAP as PROVIDER_MAP,
+)
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
@@ -177,6 +155,17 @@ class AgentConfig:
anonymous_enabled: bool = False
quota_reserve_tokens: int | None = None
+ # Capability flag: best-effort True for the chat selector / catalog.
+ # Resolved via :func:`provider_capabilities.derive_supports_image_input`
+ # which prefers OpenRouter's ``architecture.input_modalities`` and
+ # otherwise consults LiteLLM's authoritative model map. Default True
+ # is the conservative-allow stance — the streaming-task safety net
+ # (``is_known_text_only_chat_model``) is the *only* place a False
+ # actually blocks a request. Setting this to False here without an
+ # authoritative source would silently hide vision-capable models
+ # (the regression we're fixing).
+ supports_image_input: bool = True
+
@classmethod
def from_auto_mode(cls) -> "AgentConfig":
"""
@@ -202,6 +191,12 @@ class AgentConfig:
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
+ # Auto routes across the configured pool, which usually
+ # contains at least one vision-capable deployment; the router
+ # will surface a 404 from a non-vision deployment as a normal
+ # ``allowed_fails`` event and fail over rather than blocking
+ # the request outright.
+ supports_image_input=True,
)
@classmethod
@@ -215,10 +210,24 @@ class AgentConfig:
Returns:
AgentConfig instance
"""
- return cls(
- provider=config.provider.value
+ # Lazy import to avoid pulling provider_capabilities (and its
+ # transitive litellm import) into module-init order.
+ from app.services.provider_capabilities import derive_supports_image_input
+
+ provider_value = (
+ config.provider.value
if hasattr(config.provider, "value")
- else str(config.provider),
+ else str(config.provider)
+ )
+ litellm_params = config.litellm_params or {}
+ base_model = (
+ litellm_params.get("base_model")
+ if isinstance(litellm_params, dict)
+ else None
+ )
+
+ return cls(
+ provider=provider_value,
model_name=config.model_name,
api_key=config.api_key,
api_base=config.api_base,
@@ -234,6 +243,16 @@ class AgentConfig:
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
+ # BYOK rows have no operator-curated capability flag, so we
+ # ask LiteLLM (default-allow on unknown). The streaming
+ # safety net still blocks if the model is *explicitly*
+ # marked text-only.
+ supports_image_input=derive_supports_image_input(
+ provider=provider_value,
+ model_name=config.model_name,
+ base_model=base_model,
+ custom_provider=config.custom_provider,
+ ),
)
@classmethod
@@ -252,15 +271,46 @@ class AgentConfig:
Returns:
AgentConfig instance
"""
+ # Lazy import to avoid pulling provider_capabilities (and its
+ # transitive litellm import) into module-init order.
+ from app.services.provider_capabilities import derive_supports_image_input
+
# Get system instructions from YAML, default to empty string
system_instructions = yaml_config.get("system_instructions", "")
+ provider = yaml_config.get("provider", "").upper()
+ model_name = yaml_config.get("model_name", "")
+ custom_provider = yaml_config.get("custom_provider")
+ litellm_params = yaml_config.get("litellm_params") or {}
+ base_model = (
+ litellm_params.get("base_model")
+ if isinstance(litellm_params, dict)
+ else None
+ )
+
+ # Explicit YAML override wins; otherwise derive from LiteLLM /
+ # OpenRouter modalities. The YAML loader already populates this
+ # field, but this method is also called from
+ # ``load_global_llm_config_by_id``'s file fallback (hot reload),
+ # so we re-derive here for safety. The bool() coercion preserves
+ # the loader's behaviour for explicit ``true`` / ``false``
+ # strings that PyYAML may surface.
+ if "supports_image_input" in yaml_config:
+ supports_image_input = bool(yaml_config.get("supports_image_input"))
+ else:
+ supports_image_input = derive_supports_image_input(
+ provider=provider,
+ model_name=model_name,
+ base_model=base_model,
+ custom_provider=custom_provider,
+ )
+
return cls(
- provider=yaml_config.get("provider", "").upper(),
- model_name=yaml_config.get("model_name", ""),
+ provider=provider,
+ model_name=model_name,
api_key=yaml_config.get("api_key", ""),
api_base=yaml_config.get("api_base"),
- custom_provider=yaml_config.get("custom_provider"),
+ custom_provider=custom_provider,
litellm_params=yaml_config.get("litellm_params"),
# Prompt configuration from YAML (with defaults for backwards compatibility)
system_instructions=system_instructions if system_instructions else None,
@@ -275,6 +325,7 @@ class AgentConfig:
is_premium=yaml_config.get("billing_tier", "free") == "premium",
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
+ supports_image_input=supports_image_input,
)
@@ -494,6 +545,11 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string)
+ # Configure LiteLLM-native prompt caching (cache_control_injection_points
+ # for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
+ # ``agent_config=None`` here — the YAML path doesn't have provider intent
+ # in a structured form, so we set only the universal injection points.
+ apply_litellm_prompt_caching(llm)
return llm
@@ -518,7 +574,16 @@ def create_chat_litellm_from_agent_config(
print("Error: Auto mode requested but LLM Router not initialized")
return None
try:
- return get_auto_mode_llm()
+ router_llm = get_auto_mode_llm()
+ if router_llm is not None:
+ # Universal cache_control_injection_points only — auto-mode
+ # fans out across providers, so OpenAI-only kwargs (e.g.
+ # ``prompt_cache_key``) are left off here. ``drop_params``
+ # would strip them at the provider boundary anyway, but
+ # there's no point setting them when we don't know the
+ # destination.
+ apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
+ return router_llm
except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}")
return None
@@ -549,4 +614,9 @@ def create_chat_litellm_from_agent_config(
llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string)
+ # Build-time prompt caching: sets ``cache_control_injection_points`` for
+ # all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
+ # Per-thread ``prompt_cache_key`` is layered on later in
+ # ``create_surfsense_deep_agent`` once ``thread_id`` is known.
+ apply_litellm_prompt_caching(llm, agent_config=agent_config)
return llm
diff --git a/surfsense_backend/app/agents/new_chat/middleware/__init__.py b/surfsense_backend/app/agents/new_chat/middleware/__init__.py
index 094c102f8..6742bd8de 100644
--- a/surfsense_backend/app/agents/new_chat/middleware/__init__.py
+++ b/surfsense_backend/app/agents/new_chat/middleware/__init__.py
@@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import (
from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware,
)
+from app.agents.new_chat.middleware.flatten_system import (
+ FlattenSystemMessageMiddleware,
+)
from app.agents.new_chat.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware,
commit_staged_filesystem_state,
@@ -61,6 +64,7 @@ __all__ = [
"DedupHITLToolCallsMiddleware",
"DoomLoopMiddleware",
"FileIntentMiddleware",
+ "FlattenSystemMessageMiddleware",
"KnowledgeBasePersistenceMiddleware",
"KnowledgeBaseSearchMiddleware",
"KnowledgePriorityMiddleware",
diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
index 4b5ad546d..e7d9b8f75 100644
--- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
+++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
@@ -33,6 +33,7 @@ from __future__ import annotations
import asyncio
import logging
+import time
import weakref
from typing import Any
@@ -58,6 +59,11 @@ class _ThreadLockManager:
weakref.WeakValueDictionary()
)
self._cancel_events: dict[str, asyncio.Event] = {}
+ self._cancel_requested_at_ms: dict[str, int] = {}
+ self._cancel_attempt_count: dict[str, int] = {}
+ # Monotonic per-thread epoch used to prevent stale middleware
+ # teardown from releasing a newer turn's lock.
+ self._turn_epoch: dict[str, int] = {}
def lock_for(self, thread_id: str) -> asyncio.Lock:
lock = self._locks.get(thread_id)
@@ -76,14 +82,57 @@ class _ThreadLockManager:
def request_cancel(self, thread_id: str) -> bool:
event = self._cancel_events.get(thread_id)
if event is None:
- return False
+ event = asyncio.Event()
+ self._cancel_events[thread_id] = event
event.set()
+ now_ms = int(time.time() * 1000)
+ self._cancel_requested_at_ms[thread_id] = now_ms
+ self._cancel_attempt_count[thread_id] = (
+ self._cancel_attempt_count.get(thread_id, 0) + 1
+ )
return True
+ def is_cancel_requested(self, thread_id: str) -> bool:
+ event = self._cancel_events.get(thread_id)
+ return bool(event and event.is_set())
+
+ def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
+ if not self.is_cancel_requested(thread_id):
+ return None
+ attempts = self._cancel_attempt_count.get(thread_id, 1)
+ requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
+ return attempts, requested_at_ms
+
def reset(self, thread_id: str) -> None:
event = self._cancel_events.get(thread_id)
if event is not None:
event.clear()
+ self._cancel_requested_at_ms.pop(thread_id, None)
+ self._cancel_attempt_count.pop(thread_id, None)
+
+ def bump_turn_epoch(self, thread_id: str) -> int:
+ epoch = self._turn_epoch.get(thread_id, 0) + 1
+ self._turn_epoch[thread_id] = epoch
+ return epoch
+
+ def current_turn_epoch(self, thread_id: str) -> int:
+ return self._turn_epoch.get(thread_id, 0)
+
+ def end_turn(self, thread_id: str) -> None:
+ """Best-effort terminal cleanup for a thread turn.
+
+        This is intentionally idempotent and safe to call from an outer
+        stream ``finally`` block where middleware teardown might be skipped
+        due to abort or disconnect edge cases.
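+
+        A sketch of the intended call site (names illustrative, not the
+        actual streaming handler)::
+
+            try:
+                async for event in agent.astream_events(payload, config):
+                    ...
+            finally:
+                # Safe even when ``aafter_agent`` already released the
+                # lock: the epoch bump makes any stale teardown a no-op.
+                manager.end_turn(thread_id)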
+ """
+ # Invalidate any in-flight middleware holder first. This guarantees a
+        # stale ``aafter_agent`` from an older attempt cannot release the
+        # lock out from under a newer retry for the same thread.
+ self.bump_turn_epoch(thread_id)
+ lock = self._locks.get(thread_id)
+ if lock is not None and lock.locked():
+ lock.release()
+ self.reset(thread_id)
def release(self, thread_id: str) -> bool:
"""Force-release the per-thread lock; safety-net for turns that end before ``__end__``.
@@ -115,18 +164,28 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
def request_cancel(thread_id: str) -> bool:
- """Trip the cancel event for ``thread_id``. Returns True if found."""
+ """Trip the cancel event for ``thread_id``. Always returns True."""
return manager.request_cancel(thread_id)
+def is_cancel_requested(thread_id: str) -> bool:
+ """Return whether ``thread_id`` currently has a pending cancel signal."""
+ return manager.is_cancel_requested(thread_id)
+
+
+def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
+ """Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
+ return manager.cancel_state(thread_id)
+
+
def reset_cancel(thread_id: str) -> None:
"""Reset the cancel event for ``thread_id`` (called between turns)."""
manager.reset(thread_id)
-def release_lock(thread_id: str) -> bool:
- """Force-release the per-thread busy lock; safe to call when not held."""
- return manager.release(thread_id)
+def end_turn(thread_id: str) -> None:
+ """Force end-of-turn cleanup for lock + cancel state."""
+ manager.end_turn(thread_id)
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
@@ -151,10 +210,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
super().__init__()
self._require_thread_id = require_thread_id
self.tools = []
- # Per-call locks owned by this middleware. We track them as
- # an instance attribute so ``aafter_agent`` knows which lock
- # to release.
- self._held_locks: dict[str, asyncio.Lock] = {}
+ # Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
+ # only releases when its epoch still matches the manager's current
+ # epoch for the thread, preventing stale unlock races.
+ self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
@staticmethod
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
@@ -205,7 +264,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
if lock.locked():
raise BusyError(request_id=thread_id)
await lock.acquire()
- self._held_locks[thread_id] = lock
+ epoch = manager.bump_turn_epoch(thread_id)
+ self._held_locks[thread_id] = (lock, epoch)
# Reset the cancel event so this turn starts fresh
reset_cancel(thread_id)
return None
@@ -219,8 +279,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
thread_id = self._thread_id(runtime)
if thread_id is None:
return None
- lock = self._held_locks.pop(thread_id, None)
- if lock is not None and lock.locked():
+ held = self._held_locks.pop(thread_id, None)
+ if held is None:
+ return None
+ lock, held_epoch = held
+ if held_epoch != manager.current_turn_epoch(thread_id):
+ # Stale teardown from an older attempt (e.g. runtime-recovery path
+ # already advanced epoch). Do not touch current lock/cancel state.
+ return None
+ if lock.locked():
lock.release()
# Always clear cancel event between turns so a stale signal
# doesn't leak into the next request.
@@ -251,9 +318,11 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
__all__ = [
"BusyMutexMiddleware",
+ "end_turn",
"get_cancel_event",
+ "get_cancel_state",
+ "is_cancel_requested",
"manager",
- "release_lock",
"request_cancel",
"reset_cancel",
]
diff --git a/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py b/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py
new file mode 100644
index 000000000..29cd57aa0
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py
@@ -0,0 +1,233 @@
+r"""Coalesce multi-block system messages into a single text block.
+
+Several middlewares in our deepagent stack each call
+``append_to_system_message`` on the way down to the model
+(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
+``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the
+request reaches the LLM, the system message has 5+ separate text blocks.
+
+Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
+request**, and we configure 2 injection points
+(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
+the prepended ``request.system_message``, this middleware is the
+defensive partner: it guarantees that "the system block" is *one*
+content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
+OpenRouter→Anthropic transformer can never multiply our breakpoint
+budget by spreading ``cache_control`` across the text blocks of a
+multi-block system message.
+
+Without flattening we used to see::
+
+ OpenrouterException - {"error":{"message":"Provider returned error",
+ "code":400,"metadata":{"raw":"...A maximum of 4 blocks with
+ cache_control may be provided. Found 5."}}}
+
+(Same error class documented in
+https://github.com/BerriAI/litellm/issues/15696 and
+https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix
+in PR #15395 covers the litellm transformer but does not protect us
+when the OpenRouter SaaS itself does the redistribution.)
+
+A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching
+the first injection point from ``role: system`` to ``index: 0``)
+neutralises the *primary* cause of the same 400 — multiple
+``SystemMessage``\ s injected by ``before_agent`` middlewares
+(priority/tree/memory/file-intent/anonymous-doc) accumulating across
+turns, each tagged with ``cache_control`` by the ``role: system``
+matcher. This middleware remains useful as defence-in-depth against
+the multi-block redistribution path.
+
+Placement: innermost on the system-message-mutation chain, after every
+appender (``todo``/``filesystem``/``skills``/``subagents``) and after
+summarization, but before ``noop``/``retry``/``fallback`` so each retry
+attempt sees a flattened payload. See ``chat_deepagent.py``.
+
+Idempotent: a string-content system message is left untouched. A list
+that contains anything other than plain text blocks (e.g. an image) is
+also left untouched — those are rare on system messages and we'd lose
+the non-text payload by joining.
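+
+Sketch of where this sits in the middleware list (names as used in this
+package; the exact ordering here is illustrative, see ``chat_deepagent.py``
+for the authoritative wiring)::
+
+    middleware = [
+        TodoListMiddleware(),
+        SurfSenseFilesystemMiddleware(...),
+        SkillsMiddleware(...),
+        FlattenSystemMessageMiddleware(),  # flatten after all appenders
+        # retry / fallback wrappers come after, so every attempt
+        # re-sees a single-block system message
+    ]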
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Awaitable, Callable
+from typing import Any
+
+from langchain.agents.middleware.types import (
+ AgentMiddleware,
+ AgentState,
+ ContextT,
+ ModelRequest,
+ ModelResponse,
+ ResponseT,
+)
+from langchain_core.messages import SystemMessage
+
+logger = logging.getLogger(__name__)
+
+
+def _flatten_text_blocks(content: list[Any]) -> str | None:
+ """Return joined text if every block is a plain ``{"type": "text"}``.
+
+ Returns ``None`` when the list contains anything that isn't a text
+ block we can safely concatenate (image, audio, file, non-standard
+ blocks, dicts with extra non-cache_control fields). The caller
+ leaves the original content untouched in that case rather than
+ silently dropping payload.
+
+ ``cache_control`` on individual blocks is intentionally discarded —
+ the whole point of flattening is to let LiteLLM's
+ ``cache_control_injection_points`` re-place a single breakpoint on
+ the resulting one-block system content.
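+
+    A quick sketch of the contract (hypothetical inputs)::
+
+        >>> _flatten_text_blocks(
+        ...     [{"type": "text", "text": "a"}, {"type": "text", "text": "b"}]
+        ... )
+        'a\\n\\nb'
+        >>> _flatten_text_blocks([{"type": "image", "data": "..."}]) is None
+        True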
+ """
+ chunks: list[str] = []
+ for block in content:
+ if isinstance(block, str):
+ chunks.append(block)
+ continue
+ if not isinstance(block, dict):
+ return None
+ if block.get("type") != "text":
+ return None
+ text = block.get("text")
+ if not isinstance(text, str):
+ return None
+ chunks.append(text)
+ return "\n\n".join(chunks)
+
+
+def _flattened_request(
+ request: ModelRequest[ContextT],
+) -> ModelRequest[ContextT] | None:
+ """Return a request with system_message flattened, or ``None`` for no-op."""
+ sys_msg = request.system_message
+ if sys_msg is None:
+ return None
+ content = sys_msg.content
+ if not isinstance(content, list) or len(content) <= 1:
+ return None
+
+ flattened = _flatten_text_blocks(content)
+ if flattened is None:
+ return None
+
+ new_sys = SystemMessage(
+ content=flattened,
+ additional_kwargs=dict(sys_msg.additional_kwargs),
+ response_metadata=dict(sys_msg.response_metadata),
+ )
+ if sys_msg.id is not None:
+ new_sys.id = sys_msg.id
+ return request.override(system_message=new_sys)
+
+
+def _diagnostic_summary(request: ModelRequest[Any]) -> str:
+ """One-line dump of cache_control-relevant request shape.
+
+ Temporary diagnostic to prove where the ``Found N`` cache_control
+    breakpoints are coming from when Anthropic returns a 400. Remove
+    once the root cause is confirmed and a fix is in place.
+ """
+ sys_msg = request.system_message
+ if sys_msg is None:
+ sys_shape = "none"
+ elif isinstance(sys_msg.content, str):
+ sys_shape = f"str(len={len(sys_msg.content)})"
+ elif isinstance(sys_msg.content, list):
+ sys_shape = f"list(blocks={len(sys_msg.content)})"
+ else:
+ sys_shape = f"other({type(sys_msg.content).__name__})"
+
+ role_hist: list[str] = []
+ multi_block_msgs = 0
+ msgs_with_cc = 0
+ sys_msgs_in_history = 0
+ for m in request.messages:
+ mtype = getattr(m, "type", type(m).__name__)
+ role_hist.append(mtype)
+ if isinstance(m, SystemMessage):
+ sys_msgs_in_history += 1
+ c = getattr(m, "content", None)
+ if isinstance(c, list):
+ multi_block_msgs += 1
+ for blk in c:
+ if isinstance(blk, dict) and "cache_control" in blk:
+ msgs_with_cc += 1
+ break
+ if "cache_control" in getattr(m, "additional_kwargs", {}) or {}:
+ msgs_with_cc += 1
+
+ tools = request.tools or []
+ tools_with_cc = 0
+ for t in tools:
+ if isinstance(t, dict) and (
+ "cache_control" in t or "cache_control" in t.get("function", {})
+ ):
+ tools_with_cc += 1
+
+ return (
+ f"sys={sys_shape} msgs={len(request.messages)} "
+ f"sys_msgs_in_history={sys_msgs_in_history} "
+ f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
+ f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
+ f"roles={role_hist[-8:]}"
+ )
+
+
+class FlattenSystemMessageMiddleware(
+ AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
+):
+ """Collapse a multi-text-block system message to a single string.
+
+ Sits innermost on the system-message-mutation chain so it observes
+ every middleware's contribution. Has no other side effect — the
+ body of every block is preserved, just joined with ``"\\n\\n"``.
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.tools = []
+
+ def wrap_model_call( # type: ignore[override]
+ self,
+ request: ModelRequest[ContextT],
+ handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
+ ) -> Any:
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
+ flattened = _flattened_request(request)
+ if flattened is not None:
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(
+ "[flatten_system] collapsed %d system blocks to one",
+ len(request.system_message.content), # type: ignore[arg-type, union-attr]
+ )
+ return handler(flattened)
+ return handler(request)
+
+ async def awrap_model_call( # type: ignore[override]
+ self,
+ request: ModelRequest[ContextT],
+ handler: Callable[
+ [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
+ ],
+ ) -> Any:
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
+ flattened = _flattened_request(request)
+ if flattened is not None:
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(
+ "[flatten_system] collapsed %d system blocks to one",
+ len(request.system_message.content), # type: ignore[arg-type, union-attr]
+ )
+ return await handler(flattened)
+ return await handler(request)
+
+
+__all__ = [
+ "FlattenSystemMessageMiddleware",
+ "_flatten_text_blocks",
+ "_flattened_request",
+]
diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py
index 0820e8c3e..ee5c1d182 100644
--- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py
+++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py
@@ -732,7 +732,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
- del runtime
if self.filesystem_mode != FilesystemMode.CLOUD:
return None
@@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
if anon_doc:
return self._anon_priority(state, anon_doc)
- return await self._authenticated_priority(state, messages, user_text)
+ return await self._authenticated_priority(state, messages, user_text, runtime)
def _anon_priority(
self,
@@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState,
messages: Sequence[BaseMessage],
user_text: str,
+ runtime: Runtime[Any] | None = None,
) -> dict[str, Any]:
t0 = asyncio.get_event_loop().time()
(
@@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
user_text=user_text,
)
+ # Per-turn ``mentioned_document_ids`` flow:
+ # 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
+ # streaming task supplies a fresh :class:`SurfSenseContextSchema`
+ # on every ``astream_events`` call, so this list is naturally
+ # scoped to the current turn. Allows cross-turn graph reuse via
+ # ``agent_cache``.
+ # 2. Legacy fallback (cache disabled / context not propagated): the
+ # constructor-injected ``self.mentioned_document_ids`` list. We
+ # drain it after the first read so a cached graph (no Phase 1.5
+ # wiring) doesn't keep replaying the same mentions on every
+ # turn.
+ #
+ # CRITICAL: distinguish "context absent" (legacy caller, no field at
+ # all) from "context provided but empty" (turn with no mentions).
+ # ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
+ # Python, so a naive ``if ctx_mentions:`` would fall through to the
+ # legacy closure on every no-mention follow-up turn — replaying the
+ # mentions baked in by turn 1's cache-miss build. Always drain the
+ # closure once the runtime path has fired so a cached middleware
+ # instance can never resurrect stale state.
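+        #
+        # Concretely (values illustrative):
+        #   ctx.mentioned_document_ids == [12, 34] -> use [12, 34], drain closure
+        #   ctx.mentioned_document_ids == []       -> use [], drain closure
+        #   runtime context absent (legacy caller) -> use closure list once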
+ mention_ids: list[int] = []
+ ctx = getattr(runtime, "context", None) if runtime is not None else None
+ ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
+ if ctx_mentions is not None:
+ # Runtime path is authoritative — even an empty list means
+ # "this turn has no mentions", NOT "look at the closure".
+ mention_ids = list(ctx_mentions)
+ if self.mentioned_document_ids:
+ self.mentioned_document_ids = []
+ elif self.mentioned_document_ids:
+ mention_ids = list(self.mentioned_document_ids)
+ self.mentioned_document_ids = []
+
mentioned_results: list[dict[str, Any]] = []
- if self.mentioned_document_ids:
+ if mention_ids:
mentioned_results = await fetch_mentioned_documents(
- document_ids=self.mentioned_document_ids,
+ document_ids=mention_ids,
search_space_id=self.search_space_id,
)
- self.mentioned_document_ids = []
if is_recency:
doc_types = _resolve_search_types(
diff --git a/surfsense_backend/app/agents/new_chat/prompt_caching.py b/surfsense_backend/app/agents/new_chat/prompt_caching.py
new file mode 100644
index 000000000..9fe47cdac
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/prompt_caching.py
@@ -0,0 +1,188 @@
+r"""LiteLLM-native prompt caching configuration for SurfSense agents.
+
+Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
+activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
+gate always failed) with LiteLLM's universal caching mechanism.
+
+Coverage:
+
+- Marker-based providers (need ``cache_control`` injection, which LiteLLM
+ performs automatically when ``cache_control_injection_points`` is set):
+ ``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
+ ``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
+ (Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
+- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
+  ``deepseek/``, ``xai/`` — these cache automatically for prompts ≥1024
+ tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
+
+We inject **two** breakpoints per request:
+
+- ``index: 0`` — pins the SurfSense system prompt at the head of the
+ request (provider variant, citation rules, tool catalog, KB tree,
+ skills metadata). The langchain agent factory always prepends
+ ``request.system_message`` at index 0 (see ``factory.py``
+ ``_execute_model_async``), so this targets exactly the main system
+ prompt regardless of how many other ``SystemMessage``\ s the
+ ``before_agent`` injectors (priority, tree, memory, file-intent,
+ anonymous-doc) have inserted into ``state["messages"]``. Using
+ ``role: system`` here would apply ``cache_control`` to **every**
+ system-role message and trip Anthropic's hard cap of 4 cache
+ breakpoints per request once the conversation accumulates enough
+ injected system messages — which surfaces as the upstream 400
+ ``A maximum of 4 blocks with cache_control may be provided. Found N``
+ via OpenRouter→Anthropic.
+- ``index: -1`` — pins the latest message so multi-turn savings compound:
+ Anthropic-family providers use longest-matching-prefix lookup, so turn
+ N+1 still reads turn N's cache up to the shared prefix.
+
+For OpenAI-family configs we additionally pass:
+
+- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
+ raises hit rate by sending requests with a shared prefix to the same
+ backend.
+- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
+ 5-10 min in-memory cache.
+
+Safety net: ``litellm.drop_params=True`` is set globally in
+``app.services.llm_service`` at module-load time. Any kwarg the destination
+provider doesn't recognise is auto-stripped at the provider transformer
+layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
+``prompt_cache_key`` etc.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from langchain_core.language_models import BaseChatModel
+
+if TYPE_CHECKING:
+ from app.agents.new_chat.llm_config import AgentConfig
+
+logger = logging.getLogger(__name__)
+
+
+# Two-breakpoint policy: head-of-request + latest message. See module
+# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
+# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
+#
+# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
+# ``before_agent`` middlewares (priority, tree, memory, file-intent,
+# anonymous-doc) insert ``SystemMessage`` instances into
+# ``state["messages"]`` that accumulate across turns. With
+# ``role: system`` the LiteLLM hook would tag *every* one of them with
+# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
+# always targets the langchain-prepended ``request.system_message``
+# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
+# block), giving us exactly one stable cache breakpoint.
+_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
+ {"location": "message", "index": 0},
+ {"location": "message", "index": -1},
+)
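+
+# For reference, LiteLLM applies these points by tagging the targeted
+# message's content with an ephemeral cache marker, e.g. (shape per
+# Anthropic's prompt-caching API; text abridged):
+#
+#   {"type": "text", "text": "<system prompt>",
+#    "cache_control": {"type": "ephemeral"}}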
+
+# Providers (uppercase ``AgentConfig.provider`` values) that natively expose
+# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and
+# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers
+# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without
+# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU,
+# MINIMAX), so we can't infer family from the litellm prefix alone.
+_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"})
+
+
+def _is_router_llm(llm: BaseChatModel) -> bool:
+ """Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.
+
+ Importing ``app.services.llm_router_service`` at module-load time would
+ create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
+ Class-name comparison is sufficient since the class is defined in a
+ single place.
+ """
+ return type(llm).__name__ == "ChatLiteLLMRouter"
+
+
+def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
+ """Whether the config targets an OpenAI-style prompt-cache surface.
+
+ Strict — only returns True when the user explicitly chose OPENAI,
+ DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` /
+ ``YAMLConfig``. Auto-mode and custom providers return False because
+ we can't statically know the destination.
+ """
+ if agent_config is None or not agent_config.provider:
+ return False
+ if agent_config.is_auto_mode:
+ return False
+ if agent_config.custom_provider:
+ return False
+ return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS
+
+
+def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
+ """Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail.
+
+ Initialises the field to ``{}`` when present-but-None on a Pydantic v2
+ model. Returns ``None`` if the LLM type doesn't expose a writable
+ ``model_kwargs`` attribute (caller should treat as no-op).
+ """
+ model_kwargs = getattr(llm, "model_kwargs", None)
+ if isinstance(model_kwargs, dict):
+ return model_kwargs
+ try:
+ llm.model_kwargs = {} # type: ignore[attr-defined]
+ except Exception:
+ return None
+ refreshed = getattr(llm, "model_kwargs", None)
+ return refreshed if isinstance(refreshed, dict) else None
+
+
+def apply_litellm_prompt_caching(
+ llm: BaseChatModel,
+ *,
+ agent_config: AgentConfig | None = None,
+ thread_id: int | None = None,
+) -> None:
+ """Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.
+
+ Idempotent — values already present in ``llm.model_kwargs`` (e.g. from
+ ``agent_config.litellm_params`` overrides) are preserved. Mutates
+ ``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion``
+ via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge
+ in our custom ``ChatLiteLLMRouter``.
+
+ Args:
+ llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
+ agent_config: Optional ``AgentConfig`` driving provider-specific
+ behaviour. When omitted (or auto-mode), only the universal
+ ``cache_control_injection_points`` are set.
+ thread_id: Optional thread id used to construct a per-thread
+ ``prompt_cache_key`` for OpenAI-family providers. Caching still
+ works without it (server-side automatic), but the key improves
+ backend routing affinity and therefore hit rate.
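+
+    Example (a sketch; the model string, ``cfg``, and thread id are
+    illustrative, with ``cfg`` an OPENAI-provider ``AgentConfig``)::
+
+        llm = SanitizedChatLiteLLM(model="openai/gpt-4o-mini")
+        apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=7)
+        # llm.model_kwargs now carries cache_control_injection_points and,
+        # because the provider is OpenAI-family, prompt_cache_key
+        # "surfsense-thread-7" plus prompt_cache_retention "24h".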
+ """
+ model_kwargs = _get_or_init_model_kwargs(llm)
+ if model_kwargs is None:
+ logger.debug(
+ "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
+ type(llm).__name__,
+ )
+ return
+
+ if "cache_control_injection_points" not in model_kwargs:
+ model_kwargs["cache_control_injection_points"] = [
+ dict(point) for point in _DEFAULT_INJECTION_POINTS
+ ]
+
+ # OpenAI-family extras only when we statically know the destination is
+ # OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers
+ # so we can't safely set OpenAI-only kwargs there (drop_params would
+ # strip them but it's wasteful to set them in the first place).
+ if _is_router_llm(llm):
+ return
+ if not _is_openai_family_config(agent_config):
+ return
+
+ if thread_id is not None and "prompt_cache_key" not in model_kwargs:
+ model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
+ if "prompt_cache_retention" not in model_kwargs:
+ model_kwargs["prompt_cache_retention"] = "24h"
diff --git a/surfsense_backend/app/agents/new_chat/subagents/constants.py b/surfsense_backend/app/agents/new_chat/subagents/constants.py
index ef5a33e22..cb1da499b 100644
--- a/surfsense_backend/app/agents/new_chat/subagents/constants.py
+++ b/surfsense_backend/app/agents/new_chat/subagents/constants.py
@@ -27,14 +27,9 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = (
NON_PROVIDER_STATE_MUTATION_DENY: frozenset[str] = frozenset(
{
# Exact tool names from shared deny patterns.
- *{
- name
- for name in WRITE_TOOL_DENY_PATTERNS
- if "*" not in name
- },
+ *{name for name in WRITE_TOOL_DENY_PATTERNS if "*" not in name},
# Additional non-provider state mutation controls.
"write_todos",
"task",
}
)
-
diff --git a/surfsense_backend/app/agents/new_chat/subagents/providers/linear.py b/surfsense_backend/app/agents/new_chat/subagents/providers/linear.py
index 238b13e8e..da332fe28 100644
--- a/surfsense_backend/app/agents/new_chat/subagents/providers/linear.py
+++ b/surfsense_backend/app/agents/new_chat/subagents/providers/linear.py
@@ -112,10 +112,7 @@ def _permission_middleware(*, selected_tools: Sequence[BaseTool]) -> Any:
Rule(permission=name, pattern="*", action="deny")
for name in NON_PROVIDER_STATE_MUTATION_DENY
)
- rules.extend(
- Rule(permission=name, pattern="*", action="ask")
- for name in ask_tools
- )
+ rules.extend(Rule(permission=name, pattern="*", action="ask") for name in ask_tools)
return PermissionMiddleware(
rulesets=[Ruleset(rules=rules, origin="subagent_linear_specialist")]
)
@@ -163,4 +160,3 @@ def build_linear_specialist_subagent(
if model is not None:
spec["model"] = model
return spec # type: ignore[return-value]
-
diff --git a/surfsense_backend/app/agents/new_chat/subagents/providers/slack.py b/surfsense_backend/app/agents/new_chat/subagents/providers/slack.py
index b72edeee8..90ca80152 100644
--- a/surfsense_backend/app/agents/new_chat/subagents/providers/slack.py
+++ b/surfsense_backend/app/agents/new_chat/subagents/providers/slack.py
@@ -119,10 +119,7 @@ def _permission_middleware(*, selected_tools: Sequence[BaseTool]) -> Any:
Rule(permission=name, pattern="*", action="deny")
for name in NON_PROVIDER_STATE_MUTATION_DENY
)
- rules.extend(
- Rule(permission=name, pattern="*", action="ask")
- for name in ask_tools
- )
+ rules.extend(Rule(permission=name, pattern="*", action="ask") for name in ask_tools)
return PermissionMiddleware(
rulesets=[Ruleset(rules=rules, origin="subagent_slack_specialist")]
)
@@ -171,4 +168,3 @@ def build_slack_specialist_subagent(
if model is not None:
spec["model"] = model
return spec # type: ignore[return-value]
-
diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py
index 095413bdb..c56db1528 100644
--- a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py
@@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
+from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
@@ -18,6 +19,23 @@ def create_create_confluence_page_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """
+ Factory function to create the create_confluence_page tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured create_confluence_page tool
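+
+    Session pattern used by the tool body (sketch of the lines below)::
+
+        async with async_session_maker() as db_session:
+            metadata_service = ConfluenceToolMetadataService(db_session)
+            context = await metadata_service.get_creation_context(
+                search_space_id, user_id
+            )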
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_confluence_page(
title: str,
@@ -42,160 +60,163 @@ def create_create_confluence_page_tool(
"""
logger.info(f"create_confluence_page called: title='{title}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
- metadata_service = ConfluenceToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
+ async with async_session_maker() as db_session:
+ metadata_service = ConfluenceToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
+ )
- if "error" in context:
- return {"status": "error", "message": context["error"]}
+ if "error" in context:
+ return {"status": "error", "message": context["error"]}
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- return {
- "status": "auth_error",
- "message": "All connected Confluence accounts need re-authentication.",
- "connector_type": "confluence",
- }
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ return {
+ "status": "auth_error",
+ "message": "All connected Confluence accounts need re-authentication.",
+ "connector_type": "confluence",
+ }
- result = request_approval(
- action_type="confluence_page_creation",
- tool_name="create_confluence_page",
- params={
- "title": title,
- "content": content,
- "space_id": space_id,
- "connector_id": connector_id,
- },
- context=context,
- )
+ result = request_approval(
+ action_type="confluence_page_creation",
+ tool_name="create_confluence_page",
+ params={
+ "title": title,
+ "content": content,
+ "space_id": space_id,
+ "connector_id": connector_id,
+ },
+ context=context,
+ )
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
- final_title = result.params.get("title", title)
- final_content = result.params.get("content", content) or ""
- final_space_id = result.params.get("space_id", space_id)
- final_connector_id = result.params.get("connector_id", connector_id)
+ final_title = result.params.get("title", title)
+ final_content = result.params.get("content", content) or ""
+ final_space_id = result.params.get("space_id", space_id)
+ final_connector_id = result.params.get("connector_id", connector_id)
- if not final_title or not final_title.strip():
- return {"status": "error", "message": "Page title cannot be empty."}
- if not final_space_id:
- return {"status": "error", "message": "A space must be selected."}
+ if not final_title or not final_title.strip():
+ return {"status": "error", "message": "Page title cannot be empty."}
+ if not final_space_id:
+ return {"status": "error", "message": "A space must be selected."}
- from sqlalchemy.future import select
+ from sqlalchemy.future import select
- from app.db import SearchSourceConnector, SearchSourceConnectorType
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
- actual_connector_id = final_connector_id
- if actual_connector_id is None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+ actual_connector_id = final_connector_id
+ if actual_connector_id is None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+ )
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Confluence connector found.",
- }
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == actual_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
- )
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Confluence connector is invalid.",
- }
-
- try:
- client = ConfluenceHistoryConnector(
- session=db_session, connector_id=actual_connector_id
- )
- api_result = await client.create_page(
- space_id=final_space_id,
- title=final_title,
- body=final_content,
- )
- await client.close()
- except Exception as api_err:
- if (
- "http 403" in str(api_err).lower()
- or "status code 403" in str(api_err).lower()
- ):
- try:
- _conn = connector
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- pass
- return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
- }
- raise
-
- page_id = str(api_result.get("id", ""))
- page_links = (
- api_result.get("_links", {}) if isinstance(api_result, dict) else {}
- )
- page_url = ""
- if page_links.get("base") and page_links.get("webui"):
- page_url = f"{page_links['base']}{page_links['webui']}"
-
- kb_message_suffix = ""
- try:
- from app.services.confluence import ConfluenceKBSyncService
-
- kb_service = ConfluenceKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- page_id=page_id,
- page_title=final_title,
- space_id=final_space_id,
- body_content=final_content,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Confluence connector found.",
+ }
+ actual_connector_id = connector.id
else:
- kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == actual_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Confluence connector is invalid.",
+ }
- return {
- "status": "success",
- "page_id": page_id,
- "page_url": page_url,
- "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
- }
+ try:
+ client = ConfluenceHistoryConnector(
+ session=db_session, connector_id=actual_connector_id
+ )
+ api_result = await client.create_page(
+ space_id=final_space_id,
+ title=final_title,
+ body=final_content,
+ )
+ await client.close()
+ except Exception as api_err:
+ if (
+ "http 403" in str(api_err).lower()
+ or "status code 403" in str(api_err).lower()
+ ):
+ try:
+ _conn = connector
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ page_id = str(api_result.get("id", ""))
+ page_links = (
+ api_result.get("_links", {}) if isinstance(api_result, dict) else {}
+ )
+ page_url = ""
+ if page_links.get("base") and page_links.get("webui"):
+ page_url = f"{page_links['base']}{page_links['webui']}"
+
+ kb_message_suffix = ""
+ try:
+ from app.services.confluence import ConfluenceKBSyncService
+
+ kb_service = ConfluenceKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ page_id=page_id,
+ page_title=final_title,
+ space_id=final_space_id,
+ body_content=final_content,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "page_id": page_id,
+ "page_url": page_url,
+ "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py
index 7c03c2760..d4cd5032f 100644
--- a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py
@@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
+from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
@@ -18,6 +19,23 @@ def create_delete_confluence_page_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """
+ Factory function to create the delete_confluence_page tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured delete_confluence_page tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def delete_confluence_page(
page_title_or_id: str,
@@ -43,137 +61,143 @@ def create_delete_confluence_page_tool(
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
- metadata_service = ConfluenceToolMetadataService(db_session)
- context = await metadata_service.get_deletion_context(
- search_space_id, user_id, page_title_or_id
- )
-
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
- return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "confluence",
- }
- if "not found" in error_msg.lower():
- return {"status": "not_found", "message": error_msg}
- return {"status": "error", "message": error_msg}
-
- page_data = context["page"]
- page_id = page_data["page_id"]
- page_title = page_data.get("page_title", "")
- document_id = page_data["document_id"]
- connector_id_from_context = context.get("account", {}).get("id")
-
- result = request_approval(
- action_type="confluence_page_deletion",
- tool_name="delete_confluence_page",
- params={
- "page_id": page_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_page_id = result.params.get("page_id", page_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this page.",
- }
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+ async with async_session_maker() as db_session:
+ metadata_service = ConfluenceToolMetadataService(db_session)
+ context = await metadata_service.get_deletion_context(
+ search_space_id, user_id, page_title_or_id
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Confluence connector is invalid.",
- }
- try:
- client = ConfluenceHistoryConnector(
- session=db_session, connector_id=final_connector_id
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "confluence",
+ }
+ if "not found" in error_msg.lower():
+ return {"status": "not_found", "message": error_msg}
+ return {"status": "error", "message": error_msg}
+
+ page_data = context["page"]
+ page_id = page_data["page_id"]
+ page_title = page_data.get("page_title", "")
+ document_id = page_data["document_id"]
+ connector_id_from_context = context.get("account", {}).get("id")
+
+ result = request_approval(
+ action_type="confluence_page_deletion",
+ tool_name="delete_confluence_page",
+ params={
+ "page_id": page_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- await client.delete_page(final_page_id)
- await client.close()
- except Exception as api_err:
- if (
- "http 403" in str(api_err).lower()
- or "status code 403" in str(api_err).lower()
- ):
- try:
- connector.config = {**connector.config, "auth_expired": True}
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- pass
+
+ if result.rejected:
return {
- "status": "insufficient_permissions",
- "connector_id": final_connector_id,
- "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
}
- raise
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- from app.db import Document
+ final_page_id = result.params.get("page_id", page_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this page.",
+ }
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Confluence connector is invalid.",
+ }
- message = f"Confluence page '{page_title}' deleted successfully."
- if deleted_from_kb:
- message += " Also removed from the knowledge base."
+ try:
+ client = ConfluenceHistoryConnector(
+ session=db_session, connector_id=final_connector_id
+ )
+ await client.delete_page(final_page_id)
+ await client.close()
+ except Exception as api_err:
+ if (
+ "http 403" in str(api_err).lower()
+ or "status code 403" in str(api_err).lower()
+ ):
+ try:
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": final_connector_id,
+ "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
- return {
- "status": "success",
- "page_id": final_page_id,
- "deleted_from_kb": deleted_from_kb,
- "message": message,
- }
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+
+ message = f"Confluence page '{page_title}' deleted successfully."
+ if deleted_from_kb:
+ message += " Also removed from the knowledge base."
+
+ return {
+ "status": "success",
+ "page_id": final_page_id,
+ "deleted_from_kb": deleted_from_kb,
+ "message": message,
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py
index 791d0d8c5..51c205e00 100644
--- a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py
@@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
+from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
@@ -18,6 +19,23 @@ def create_update_confluence_page_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """
+ Factory function to create the update_confluence_page tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured update_confluence_page tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def update_confluence_page(
page_title_or_id: str,
@@ -45,164 +63,168 @@ def create_update_confluence_page_tool(
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
- metadata_service = ConfluenceToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, page_title_or_id
- )
+ async with async_session_maker() as db_session:
+ metadata_service = ConfluenceToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, page_title_or_id
+ )
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "confluence",
+ }
+ if "not found" in error_msg.lower():
+ return {"status": "not_found", "message": error_msg}
+ return {"status": "error", "message": error_msg}
+
+ page_data = context["page"]
+ page_id = page_data["page_id"]
+ current_title = page_data["page_title"]
+ current_body = page_data.get("body", "")
+ current_version = page_data.get("version", 1)
+ document_id = page_data.get("document_id")
+ connector_id_from_context = context.get("account", {}).get("id")
+
+ result = request_approval(
+ action_type="confluence_page_update",
+ tool_name="update_confluence_page",
+ params={
+ "page_id": page_id,
+ "document_id": document_id,
+ "new_title": new_title,
+ "new_content": new_content,
+ "version": current_version,
+ "connector_id": connector_id_from_context,
+ },
+ context=context,
+ )
+
+ if result.rejected:
return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "confluence",
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
}
- if "not found" in error_msg.lower():
- return {"status": "not_found", "message": error_msg}
- return {"status": "error", "message": error_msg}
- page_data = context["page"]
- page_id = page_data["page_id"]
- current_title = page_data["page_title"]
- current_body = page_data.get("body", "")
- current_version = page_data.get("version", 1)
- document_id = page_data.get("document_id")
- connector_id_from_context = context.get("account", {}).get("id")
-
- result = request_approval(
- action_type="confluence_page_update",
- tool_name="update_confluence_page",
- params={
- "page_id": page_id,
- "document_id": document_id,
- "new_title": new_title,
- "new_content": new_content,
- "version": current_version,
- "connector_id": connector_id_from_context,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_page_id = result.params.get("page_id", page_id)
- final_title = result.params.get("new_title", new_title) or current_title
- final_content = result.params.get("new_content", new_content)
- if final_content is None:
- final_content = current_body
- final_version = result.params.get("version", current_version)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_document_id = result.params.get("document_id", document_id)
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this page.",
- }
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+ final_page_id = result.params.get("page_id", page_id)
+ final_title = result.params.get("new_title", new_title) or current_title
+ final_content = result.params.get("new_content", new_content)
+ if final_content is None:
+ final_content = current_body
+ final_version = result.params.get("version", current_version)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Confluence connector is invalid.",
- }
+ final_document_id = result.params.get("document_id", document_id)
- try:
- client = ConfluenceHistoryConnector(
- session=db_session, connector_id=final_connector_id
- )
- api_result = await client.update_page(
- page_id=final_page_id,
- title=final_title,
- body=final_content,
- version_number=final_version + 1,
- )
- await client.close()
- except Exception as api_err:
- if (
- "http 403" in str(api_err).lower()
- or "status code 403" in str(api_err).lower()
- ):
- try:
- connector.config = {**connector.config, "auth_expired": True}
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- pass
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if not final_connector_id:
return {
- "status": "insufficient_permissions",
- "connector_id": final_connector_id,
- "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "error",
+ "message": "No connector found for this page.",
}
- raise
- page_links = (
- api_result.get("_links", {}) if isinstance(api_result, dict) else {}
- )
- page_url = ""
- if page_links.get("base") and page_links.get("webui"):
- page_url = f"{page_links['base']}{page_links['webui']}"
-
- kb_message_suffix = ""
- if final_document_id:
- try:
- from app.services.confluence import ConfluenceKBSyncService
-
- kb_service = ConfluenceKBSyncService(db_session)
- kb_result = await kb_service.sync_after_update(
- document_id=final_document_id,
- page_id=final_page_id,
- user_id=user_id,
- search_space_id=search_space_id,
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
- if kb_result["status"] == "success":
- kb_message_suffix = (
- " Your knowledge base has also been updated."
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Confluence connector is invalid.",
+ }
+
+ try:
+ client = ConfluenceHistoryConnector(
+ session=db_session, connector_id=final_connector_id
+ )
+ api_result = await client.update_page(
+ page_id=final_page_id,
+ title=final_title,
+ body=final_content,
+ version_number=final_version + 1,
+ )
+ await client.close()
+ except Exception as api_err:
+ if (
+ "http 403" in str(api_err).lower()
+ or "status code 403" in str(api_err).lower()
+ ):
+ try:
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": final_connector_id,
+ "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ page_links = (
+ api_result.get("_links", {}) if isinstance(api_result, dict) else {}
+ )
+ page_url = ""
+ if page_links.get("base") and page_links.get("webui"):
+ page_url = f"{page_links['base']}{page_links['webui']}"
+
+ kb_message_suffix = ""
+ if final_document_id:
+ try:
+ from app.services.confluence import ConfluenceKBSyncService
+
+ kb_service = ConfluenceKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_update(
+ document_id=final_document_id,
+ page_id=final_page_id,
+ user_id=user_id,
+ search_space_id=search_space_id,
)
- else:
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = (
+ " The knowledge base will be updated in the next sync."
+ )
+ except Exception as kb_err:
+ logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
- except Exception as kb_err:
- logger.warning(f"KB sync after update failed: {kb_err}")
- kb_message_suffix = (
- " The knowledge base will be updated in the next sync."
- )
- return {
- "status": "success",
- "page_id": final_page_id,
- "page_url": page_url,
- "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
- }
+ return {
+ "status": "success",
+ "page_id": final_page_id,
+ "page_url": page_url,
+ "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
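
A note on the error path above: the Confluence client surfaces permission failures as plain exceptions, so the except block matches on the error text and persists auth_expired into the connector config before returning. A minimal sketch of that check factored into a helper (hypothetical; the diff keeps it inline):

    def _is_permission_error(err: Exception) -> bool:
        # Matches both "HTTP 403" and "status code 403" spellings, mirroring
        # the inline check in the update_page except block above.
        text = str(err).lower()
        return "http 403" in text or "status code 403" in text
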
diff --git a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py
index 5675a42e6..6420a90e6 100644
--- a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py
+++ b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py
@@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
from app.services.mcp_oauth.registry import MCP_SERVICES
logger = logging.getLogger(__name__)
@@ -53,6 +53,23 @@ def create_get_connected_accounts_tool(
search_space_id: int,
user_id: str,
) -> StructuredTool:
+ """Factory function to create the get_connected_accounts tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the compiled-agent cache can safely
+ share the closure across HTTP requests. Capturing a per-request
+ session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to scope account discovery to.
+ user_id: User ID to scope account discovery to.
+
+ Returns:
+ Configured StructuredTool for connected-accounts discovery.
+ """
+ del db_session # per-call session — see docstring
async def _run(service: str) -> list[dict[str, Any]]:
svc_cfg = MCP_SERVICES.get(service)
@@ -68,40 +85,41 @@ def create_get_connected_accounts_tool(
except ValueError:
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type == connector_type,
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type == connector_type,
+ )
)
- )
- connectors = result.scalars().all()
+ connectors = result.scalars().all()
- if not connectors:
- return [
- {
- "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."
+ if not connectors:
+ return [
+ {
+ "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."
+ }
+ ]
+
+ is_multi = len(connectors) > 1
+
+ accounts: list[dict[str, Any]] = []
+ for conn in connectors:
+ cfg = conn.config or {}
+ entry: dict[str, Any] = {
+ "connector_id": conn.id,
+ "display_name": _extract_display_name(conn),
+ "service": service,
}
- ]
+ if is_multi:
+ entry["tool_prefix"] = f"{service}_{conn.id}"
+ for key in svc_cfg.account_metadata_keys:
+ if key in cfg:
+ entry[key] = cfg[key]
+ accounts.append(entry)
- is_multi = len(connectors) > 1
-
- accounts: list[dict[str, Any]] = []
- for conn in connectors:
- cfg = conn.config or {}
- entry: dict[str, Any] = {
- "connector_id": conn.id,
- "display_name": _extract_display_name(conn),
- "service": service,
- }
- if is_multi:
- entry["tool_prefix"] = f"{service}_{conn.id}"
- for key in svc_cfg.account_metadata_keys:
- if key in cfg:
- entry[key] = cfg[key]
- accounts.append(entry)
-
- return accounts
+ return accounts
return StructuredTool(
name="get_connected_accounts",
diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py
index 3cc99ac17..01159a261 100644
--- a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py
+++ b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_list_discord_channels_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the list_discord_channels tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the compiled-agent cache can safely
+ share the closure across HTTP requests. Capturing a per-request
+ session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured list_discord_channels tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def list_discord_channels() -> dict[str, Any]:
"""List text channels in the connected Discord server.
@@ -22,59 +41,60 @@ def create_list_discord_channels_tool(
Returns:
Dictionary with status and a list of channels (id, name).
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Discord tool not properly configured.",
}
try:
- connector = await get_discord_connector(
- db_session, search_space_id, user_id
- )
- if not connector:
- return {"status": "error", "message": "No Discord connector found."}
-
- guild_id = get_guild_id(connector)
- if not guild_id:
- return {
- "status": "error",
- "message": "No guild ID in Discord connector config.",
- }
-
- token = get_bot_token(connector)
-
- async with httpx.AsyncClient() as client:
- resp = await client.get(
- f"{DISCORD_API}/guilds/{guild_id}/channels",
- headers={"Authorization": f"Bot {token}"},
- timeout=15.0,
+ async with async_session_maker() as db_session:
+ connector = await get_discord_connector(
+ db_session, search_space_id, user_id
)
+ if not connector:
+ return {"status": "error", "message": "No Discord connector found."}
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Discord bot token is invalid.",
- "connector_type": "discord",
- }
- if resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Discord API error: {resp.status_code}",
- }
+ guild_id = get_guild_id(connector)
+ if not guild_id:
+ return {
+ "status": "error",
+ "message": "No guild ID in Discord connector config.",
+ }
- # Type 0 = text channel
- channels = [
- {"id": ch["id"], "name": ch["name"]}
- for ch in resp.json()
- if ch.get("type") == 0
- ]
- return {
- "status": "success",
- "guild_id": guild_id,
- "channels": channels,
- "total": len(channels),
- }
+ token = get_bot_token(connector)
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.get(
+ f"{DISCORD_API}/guilds/{guild_id}/channels",
+ headers={"Authorization": f"Bot {token}"},
+ timeout=15.0,
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Discord bot token is invalid.",
+ "connector_type": "discord",
+ }
+ if resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Discord API error: {resp.status_code}",
+ }
+
+ # Type 0 = text channel
+ channels = [
+ {"id": ch["id"], "name": ch["name"]}
+ for ch in resp.json()
+ if ch.get("type") == 0
+ ]
+ return {
+ "status": "success",
+ "guild_id": guild_id,
+ "channels": channels,
+ "total": len(channels),
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
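
For reference, the "type == 0" filter above keeps only guild text channels. The Discord API encodes channel types as integers; the codes relevant here are (per Discord's API documentation):

    GUILD_TEXT = 0          # the only type list_discord_channels returns
    GUILD_VOICE = 2
    GUILD_CATEGORY = 4
    GUILD_ANNOUNCEMENT = 5
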
diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py
index d8bf989a1..88d6cdd49 100644
--- a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py
+++ b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_read_discord_messages_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the read_discord_messages tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the compiled-agent cache can safely
+ share the closure across HTTP requests. Capturing a per-request
+ session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured read_discord_messages tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def read_discord_messages(
channel_id: str,
@@ -30,7 +49,7 @@ def create_read_discord_messages_tool(
Dictionary with status and a list of messages including
id, author, content, timestamp.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Discord tool not properly configured.",
@@ -39,55 +58,56 @@ def create_read_discord_messages_tool(
limit = min(limit, 50)
try:
- connector = await get_discord_connector(
- db_session, search_space_id, user_id
- )
- if not connector:
- return {"status": "error", "message": "No Discord connector found."}
-
- token = get_bot_token(connector)
-
- async with httpx.AsyncClient() as client:
- resp = await client.get(
- f"{DISCORD_API}/channels/{channel_id}/messages",
- headers={"Authorization": f"Bot {token}"},
- params={"limit": limit},
- timeout=15.0,
+ async with async_session_maker() as db_session:
+ connector = await get_discord_connector(
+ db_session, search_space_id, user_id
)
+ if not connector:
+ return {"status": "error", "message": "No Discord connector found."}
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Discord bot token is invalid.",
- "connector_type": "discord",
- }
- if resp.status_code == 403:
- return {
- "status": "error",
- "message": "Bot lacks permission to read this channel.",
- }
- if resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Discord API error: {resp.status_code}",
- }
+ token = get_bot_token(connector)
- messages = [
- {
- "id": m["id"],
- "author": m.get("author", {}).get("username", "Unknown"),
- "content": m.get("content", ""),
- "timestamp": m.get("timestamp", ""),
- }
- for m in resp.json()
- ]
+ async with httpx.AsyncClient() as client:
+ resp = await client.get(
+ f"{DISCORD_API}/channels/{channel_id}/messages",
+ headers={"Authorization": f"Bot {token}"},
+ params={"limit": limit},
+ timeout=15.0,
+ )
- return {
- "status": "success",
- "channel_id": channel_id,
- "messages": messages,
- "total": len(messages),
- }
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Discord bot token is invalid.",
+ "connector_type": "discord",
+ }
+ if resp.status_code == 403:
+ return {
+ "status": "error",
+ "message": "Bot lacks permission to read this channel.",
+ }
+ if resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Discord API error: {resp.status_code}",
+ }
+
+ messages = [
+ {
+ "id": m["id"],
+ "author": m.get("author", {}).get("username", "Unknown"),
+ "content": m.get("content", ""),
+ "timestamp": m.get("timestamp", ""),
+ }
+ for m in resp.json()
+ ]
+
+ return {
+ "status": "success",
+ "channel_id": channel_id,
+ "messages": messages,
+ "total": len(messages),
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py
index 236cd017a..5fe6fde35 100644
--- a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py
+++ b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py
@@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
@@ -17,6 +18,23 @@ def create_send_discord_message_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the send_discord_message tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the compiled-agent cache can safely
+ share the closure across HTTP requests. Capturing a per-request
+ session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured send_discord_message tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def send_discord_message(
channel_id: str,
@@ -34,7 +52,7 @@ def create_send_discord_message_tool(
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Discord tool not properly configured.",
@@ -47,64 +65,65 @@ def create_send_discord_message_tool(
}
try:
- connector = await get_discord_connector(
- db_session, search_space_id, user_id
- )
- if not connector:
- return {"status": "error", "message": "No Discord connector found."}
+ async with async_session_maker() as db_session:
+ connector = await get_discord_connector(
+ db_session, search_space_id, user_id
+ )
+ if not connector:
+ return {"status": "error", "message": "No Discord connector found."}
- result = request_approval(
- action_type="discord_send_message",
- tool_name="send_discord_message",
- params={"channel_id": channel_id, "content": content},
- context={"connector_id": connector.id},
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Message was not sent.",
- }
-
- final_content = result.params.get("content", content)
- final_channel = result.params.get("channel_id", channel_id)
-
- token = get_bot_token(connector)
-
- async with httpx.AsyncClient() as client:
- resp = await client.post(
- f"{DISCORD_API}/channels/{final_channel}/messages",
- headers={
- "Authorization": f"Bot {token}",
- "Content-Type": "application/json",
- },
- json={"content": final_content},
- timeout=15.0,
+ result = request_approval(
+ action_type="discord_send_message",
+ tool_name="send_discord_message",
+ params={"channel_id": channel_id, "content": content},
+ context={"connector_id": connector.id},
)
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Discord bot token is invalid.",
- "connector_type": "discord",
- }
- if resp.status_code == 403:
- return {
- "status": "error",
- "message": "Bot lacks permission to send messages in this channel.",
- }
- if resp.status_code not in (200, 201):
- return {
- "status": "error",
- "message": f"Discord API error: {resp.status_code}",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Message was not sent.",
+ }
- msg_data = resp.json()
- return {
- "status": "success",
- "message_id": msg_data.get("id"),
- "message": f"Message sent to channel {final_channel}.",
- }
+ final_content = result.params.get("content", content)
+ final_channel = result.params.get("channel_id", channel_id)
+
+ token = get_bot_token(connector)
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.post(
+ f"{DISCORD_API}/channels/{final_channel}/messages",
+ headers={
+ "Authorization": f"Bot {token}",
+ "Content-Type": "application/json",
+ },
+ json={"content": final_content},
+ timeout=15.0,
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Discord bot token is invalid.",
+ "connector_type": "discord",
+ }
+ if resp.status_code == 403:
+ return {
+ "status": "error",
+ "message": "Bot lacks permission to send messages in this channel.",
+ }
+ if resp.status_code not in (200, 201):
+ return {
+ "status": "error",
+ "message": f"Discord API error: {resp.status_code}",
+ }
+
+ msg_data = resp.json()
+ return {
+ "status": "success",
+ "message_id": msg_data.get("id"),
+ "message": f"Message sent to channel {final_channel}.",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
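
All write-path tools in this diff (Confluence update, Discord send, Dropbox create/trash, Gmail draft) consume the same human-in-the-loop contract: request_approval returns an object whose rejected flag short-circuits the call and whose params mapping carries any edits the user made while approving, and those edits take precedence over the model's arguments. A sketch of the consumption pattern inside a tool body, where channel_id, content, and connector are already in scope:

    from app.agents.new_chat.tools.hitl import request_approval

    result = request_approval(
        action_type="discord_send_message",
        tool_name="send_discord_message",
        params={"channel_id": channel_id, "content": content},
        context={"connector_id": connector.id},
    )
    if result.rejected:
        # Hard stop: the user explicitly declined, so never retry.
        return {"status": "rejected", "message": "User declined."}
    # User-edited values win; the original arguments are only fallbacks.
    final_content = result.params.get("content", content)
    final_channel = result.params.get("channel_id", channel_id)

Note that identifiers coming back through result.params (such as connector_id) are re-validated against the database with the user and search-space filters before use, as the Dropbox and Gmail tools below do.
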
diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py
index 22d8a8a27..7aae034cc 100644
--- a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py
@@ -10,7 +10,7 @@ from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@@ -59,6 +59,23 @@ def create_create_dropbox_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_dropbox_file tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the compiled-agent cache can safely
+ share the closure across HTTP requests. Capturing a per-request
+ session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured create_dropbox_file tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_dropbox_file(
name: str,
@@ -82,184 +99,191 @@ def create_create_dropbox_file_tool(
f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Dropbox tool not properly configured.",
}
try:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.DROPBOX_CONNECTOR,
- )
- )
- connectors = result.scalars().all()
-
- if not connectors:
- return {
- "status": "error",
- "message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
- }
-
- accounts = []
- for c in connectors:
- cfg = c.config or {}
- accounts.append(
- {
- "id": c.id,
- "name": c.name,
- "user_email": cfg.get("user_email"),
- "auth_expired": cfg.get("auth_expired", False),
- }
- )
-
- if all(a.get("auth_expired") for a in accounts):
- return {
- "status": "auth_error",
- "message": "All connected Dropbox accounts need re-authentication.",
- "connector_type": "dropbox",
- }
-
- parent_folders: dict[int, list[dict[str, str]]] = {}
- for acc in accounts:
- cid = acc["id"]
- if acc.get("auth_expired"):
- parent_folders[cid] = []
- continue
- try:
- client = DropboxClient(session=db_session, connector_id=cid)
- items, err = await client.list_folder("")
- if err:
- logger.warning(
- "Failed to list folders for connector %s: %s", cid, err
- )
- parent_folders[cid] = []
- else:
- parent_folders[cid] = [
- {
- "folder_path": item.get("path_lower", ""),
- "name": item["name"],
- }
- for item in items
- if item.get(".tag") == "folder" and item.get("name")
- ]
- except Exception:
- logger.warning(
- "Error fetching folders for connector %s", cid, exc_info=True
- )
- parent_folders[cid] = []
-
- context: dict[str, Any] = {
- "accounts": accounts,
- "parent_folders": parent_folders,
- "supported_types": _SUPPORTED_TYPES,
- }
-
- result = request_approval(
- action_type="dropbox_file_creation",
- tool_name="create_dropbox_file",
- params={
- "name": name,
- "file_type": file_type,
- "content": content,
- "connector_id": None,
- "parent_folder_path": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_name = result.params.get("name", name)
- final_file_type = result.params.get("file_type", file_type)
- final_content = result.params.get("content", content)
- final_connector_id = result.params.get("connector_id")
- final_parent_folder_path = result.params.get("parent_folder_path")
-
- if not final_name or not final_name.strip():
- return {"status": "error", "message": "File name cannot be empty."}
-
- final_name = _ensure_extension(final_name, final_file_type)
-
- if final_connector_id is not None:
+ async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
)
)
- connector = result.scalars().first()
- else:
- connector = connectors[0]
+ connectors = result.scalars().all()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Dropbox connector is invalid.",
+ if not connectors:
+ return {
+ "status": "error",
+ "message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
+ }
+
+ accounts = []
+ for c in connectors:
+ cfg = c.config or {}
+ accounts.append(
+ {
+ "id": c.id,
+ "name": c.name,
+ "user_email": cfg.get("user_email"),
+ "auth_expired": cfg.get("auth_expired", False),
+ }
+ )
+
+ if all(a.get("auth_expired") for a in accounts):
+ return {
+ "status": "auth_error",
+ "message": "All connected Dropbox accounts need re-authentication.",
+ "connector_type": "dropbox",
+ }
+
+ parent_folders: dict[int, list[dict[str, str]]] = {}
+ for acc in accounts:
+ cid = acc["id"]
+ if acc.get("auth_expired"):
+ parent_folders[cid] = []
+ continue
+ try:
+ client = DropboxClient(session=db_session, connector_id=cid)
+ items, err = await client.list_folder("")
+ if err:
+ logger.warning(
+ "Failed to list folders for connector %s: %s", cid, err
+ )
+ parent_folders[cid] = []
+ else:
+ parent_folders[cid] = [
+ {
+ "folder_path": item.get("path_lower", ""),
+ "name": item["name"],
+ }
+ for item in items
+ if item.get(".tag") == "folder" and item.get("name")
+ ]
+ except Exception:
+ logger.warning(
+ "Error fetching folders for connector %s",
+ cid,
+ exc_info=True,
+ )
+ parent_folders[cid] = []
+
+ context: dict[str, Any] = {
+ "accounts": accounts,
+ "parent_folders": parent_folders,
+ "supported_types": _SUPPORTED_TYPES,
}
- client = DropboxClient(session=db_session, connector_id=connector.id)
-
- parent_path = final_parent_folder_path or ""
- file_path = (
- f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
- )
-
- if final_file_type == "paper":
- created = await client.create_paper_doc(file_path, final_content or "")
- file_id = created.get("file_id", "")
- web_url = created.get("url", "")
- else:
- docx_bytes = _markdown_to_docx(final_content or "")
- created = await client.upload_file(
- file_path, docx_bytes, mode="add", autorename=True
+ result = request_approval(
+ action_type="dropbox_file_creation",
+ tool_name="create_dropbox_file",
+ params={
+ "name": name,
+ "file_type": file_type,
+ "content": content,
+ "connector_id": None,
+ "parent_folder_path": None,
+ },
+ context=context,
)
- file_id = created.get("id", "")
- web_url = ""
- logger.info(f"Dropbox file created: id={file_id}, name={final_name}")
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
- kb_message_suffix = ""
- try:
- from app.services.dropbox import DropboxKBSyncService
+ final_name = result.params.get("name", name)
+ final_file_type = result.params.get("file_type", file_type)
+ final_content = result.params.get("content", content)
+ final_connector_id = result.params.get("connector_id")
+ final_parent_folder_path = result.params.get("parent_folder_path")
- kb_service = DropboxKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- file_id=file_id,
- file_name=final_name,
- file_path=file_path,
- web_url=web_url,
- content=final_content,
- connector_id=connector.id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
+ if not final_name or not final_name.strip():
+ return {"status": "error", "message": "File name cannot be empty."}
+
+ final_name = _ensure_extension(final_name, final_file_type)
+
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.DROPBOX_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
else:
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+ connector = connectors[0]
- return {
- "status": "success",
- "file_id": file_id,
- "name": final_name,
- "web_url": web_url,
- "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
- }
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Dropbox connector is invalid.",
+ }
+
+ client = DropboxClient(session=db_session, connector_id=connector.id)
+
+ parent_path = final_parent_folder_path or ""
+ file_path = (
+ f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
+ )
+
+ if final_file_type == "paper":
+ created = await client.create_paper_doc(
+ file_path, final_content or ""
+ )
+ file_id = created.get("file_id", "")
+ web_url = created.get("url", "")
+ else:
+ docx_bytes = _markdown_to_docx(final_content or "")
+ created = await client.upload_file(
+ file_path, docx_bytes, mode="add", autorename=True
+ )
+ file_id = created.get("id", "")
+ web_url = ""
+
+ logger.info(f"Dropbox file created: id={file_id}, name={final_name}")
+
+ kb_message_suffix = ""
+ try:
+ from app.services.dropbox import DropboxKBSyncService
+
+ kb_service = DropboxKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ file_id=file_id,
+ file_name=final_name,
+ file_path=file_path,
+ web_url=web_url,
+ content=final_content,
+ connector_id=connector.id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "file_id": file_id,
+ "name": final_name,
+ "web_url": web_url,
+ "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
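
_ensure_extension is defined outside this hunk; given the paper-versus-docx branch above, one plausible shape, offered purely as an assumption rather than the repository's actual helper:

    def _ensure_extension(name: str, file_type: str) -> str:
        # Hypothetical reconstruction: Paper docs are created by path and
        # need no extension, while everything else is uploaded as .docx.
        if file_type == "paper":
            return name
        return name if name.lower().endswith(".docx") else f"{name}.docx"
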
diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py
index 12559b57a..0e59e49db 100644
--- a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py
@@ -13,6 +13,7 @@ from app.db import (
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
+ async_session_maker,
)
logger = logging.getLogger(__name__)
@@ -23,6 +24,23 @@ def create_delete_dropbox_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the delete_dropbox_file tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the compiled-agent cache can safely
+ share the closure across HTTP requests. Capturing a per-request
+ session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured delete_dropbox_file tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def delete_dropbox_file(
file_name: str,
@@ -55,33 +73,14 @@ def create_delete_dropbox_file_tool(
f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Dropbox tool not properly configured.",
}
try:
- doc_result = await db_session.execute(
- select(Document)
- .join(
- SearchSourceConnector,
- Document.connector_id == SearchSourceConnector.id,
- )
- .filter(
- and_(
- Document.search_space_id == search_space_id,
- Document.document_type == DocumentType.DROPBOX_FILE,
- func.lower(Document.title) == func.lower(file_name),
- SearchSourceConnector.user_id == user_id,
- )
- )
- .order_by(Document.updated_at.desc().nullslast())
- .limit(1)
- )
- document = doc_result.scalars().first()
-
- if not document:
+ async with async_session_maker() as db_session:
doc_result = await db_session.execute(
select(Document)
.join(
@@ -92,13 +91,7 @@ def create_delete_dropbox_file_tool(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.DROPBOX_FILE,
- func.lower(
- cast(
- Document.document_metadata["dropbox_file_name"],
- String,
- )
- )
- == func.lower(file_name),
+ func.lower(Document.title) == func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
@@ -107,99 +100,63 @@ def create_delete_dropbox_file_tool(
)
document = doc_result.scalars().first()
- if not document:
- return {
- "status": "not_found",
- "message": (
- f"File '{file_name}' not found in your indexed Dropbox files. "
- "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
- "or (3) the file name is different."
- ),
- }
-
- if not document.connector_id:
- return {
- "status": "error",
- "message": "Document has no associated connector.",
- }
-
- meta = document.document_metadata or {}
- file_path = meta.get("dropbox_path")
- file_id = meta.get("dropbox_file_id")
- document_id = document.id
-
- if not file_path:
- return {
- "status": "error",
- "message": "File path is missing. Please re-index the file.",
- }
-
- conn_result = await db_session.execute(
- select(SearchSourceConnector).filter(
- and_(
- SearchSourceConnector.id == document.connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.DROPBOX_CONNECTOR,
+ if not document:
+ doc_result = await db_session.execute(
+ select(Document)
+ .join(
+ SearchSourceConnector,
+ Document.connector_id == SearchSourceConnector.id,
+ )
+ .filter(
+ and_(
+ Document.search_space_id == search_space_id,
+ Document.document_type == DocumentType.DROPBOX_FILE,
+ func.lower(
+ cast(
+ Document.document_metadata["dropbox_file_name"],
+ String,
+ )
+ )
+ == func.lower(file_name),
+ SearchSourceConnector.user_id == user_id,
+ )
+ )
+ .order_by(Document.updated_at.desc().nullslast())
+ .limit(1)
)
- )
- )
- connector = conn_result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Dropbox connector not found or access denied.",
- }
+ document = doc_result.scalars().first()
- cfg = connector.config or {}
- if cfg.get("auth_expired"):
- return {
- "status": "auth_error",
- "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "dropbox",
- }
+ if not document:
+ return {
+ "status": "not_found",
+ "message": (
+ f"File '{file_name}' not found in your indexed Dropbox files. "
+ "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
+ "or (3) the file name is different."
+ ),
+ }
- context = {
- "file": {
- "file_id": file_id,
- "file_path": file_path,
- "name": file_name,
- "document_id": document_id,
- },
- "account": {
- "id": connector.id,
- "name": connector.name,
- "user_email": cfg.get("user_email"),
- },
- }
+ if not document.connector_id:
+ return {
+ "status": "error",
+ "message": "Document has no associated connector.",
+ }
- result = request_approval(
- action_type="dropbox_file_trash",
- tool_name="delete_dropbox_file",
- params={
- "file_path": file_path,
- "connector_id": connector.id,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
+ meta = document.document_metadata or {}
+ file_path = meta.get("dropbox_path")
+ file_id = meta.get("dropbox_file_id")
+ document_id = document.id
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
+ if not file_path:
+ return {
+ "status": "error",
+ "message": "File path is missing. Please re-index the file.",
+ }
- final_file_path = result.params.get("file_path", file_path)
- final_connector_id = result.params.get("connector_id", connector.id)
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- if final_connector_id != connector.id:
- result = await db_session.execute(
+ conn_result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
- SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
@@ -207,61 +164,128 @@ def create_delete_dropbox_file_tool(
)
)
)
- validated_connector = result.scalars().first()
- if not validated_connector:
+ connector = conn_result.scalars().first()
+ if not connector:
return {
"status": "error",
- "message": "Selected Dropbox connector is invalid or has been disconnected.",
+ "message": "Dropbox connector not found or access denied.",
}
- actual_connector_id = validated_connector.id
- else:
- actual_connector_id = connector.id
- logger.info(
- f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
- )
+ cfg = connector.config or {}
+ if cfg.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "dropbox",
+ }
- client = DropboxClient(session=db_session, connector_id=actual_connector_id)
- await client.delete_file(final_file_path)
+ context = {
+ "file": {
+ "file_id": file_id,
+ "file_path": file_path,
+ "name": file_name,
+ "document_id": document_id,
+ },
+ "account": {
+ "id": connector.id,
+ "name": connector.name,
+ "user_email": cfg.get("user_email"),
+ },
+ }
- logger.info(f"Dropbox file deleted: path={final_file_path}")
-
- trash_result: dict[str, Any] = {
- "status": "success",
- "file_id": file_id,
- "message": f"Successfully deleted '{file_name}' from Dropbox.",
- }
-
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- doc = doc_result.scalars().first()
- if doc:
- await db_session.delete(doc)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- trash_result["warning"] = (
- f"File deleted, but failed to remove from knowledge base: {e!s}"
- )
-
- trash_result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- trash_result["message"] = (
- f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ result = request_approval(
+ action_type="dropbox_file_trash",
+ tool_name="delete_dropbox_file",
+ params={
+ "file_path": file_path,
+ "connector_id": connector.id,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- return trash_result
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_file_path = result.params.get("file_path", file_path)
+ final_connector_id = result.params.get("connector_id", connector.id)
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
+
+ if final_connector_id != connector.id:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ and_(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id
+ == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.DROPBOX_CONNECTOR,
+ )
+ )
+ )
+ validated_connector = result.scalars().first()
+ if not validated_connector:
+ return {
+ "status": "error",
+ "message": "Selected Dropbox connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = validated_connector.id
+ else:
+ actual_connector_id = connector.id
+
+ logger.info(
+ f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
+ )
+
+ client = DropboxClient(
+ session=db_session, connector_id=actual_connector_id
+ )
+ await client.delete_file(final_file_path)
+
+ logger.info(f"Dropbox file deleted: path={final_file_path}")
+
+ trash_result: dict[str, Any] = {
+ "status": "success",
+ "file_id": file_id,
+ "message": f"Successfully deleted '{file_name}' from Dropbox.",
+ }
+
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ doc = doc_result.scalars().first()
+ if doc:
+ await db_session.delete(doc)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ trash_result["warning"] = (
+ f"File deleted, but failed to remove from knowledge base: {e!s}"
+ )
+
+ trash_result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ trash_result["message"] = (
+ f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py
index 3803fa39c..9e287ac51 100644
--- a/surfsense_backend/app/agents/new_chat/tools/generate_image.py
+++ b/surfsense_backend/app/agents/new_chat/tools/generate_image.py
@@ -31,6 +31,7 @@ from app.services.image_gen_router_service import (
ImageGenRouterService,
is_image_gen_auto_mode,
)
+from app.services.provider_api_base import resolve_api_base
from app.utils.signed_image_urls import generate_image_token
logger = logging.getLogger(__name__)
@@ -49,12 +50,16 @@ _PROVIDER_MAP = {
}
+def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
+ if custom_provider:
+ return custom_provider
+ return _PROVIDER_MAP.get(provider.upper(), provider.lower())
+
+
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
- if custom_provider:
- return f"{custom_provider}/{model_name}"
- prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
+ prefix = _resolve_provider_prefix(provider, custom_provider)
return f"{prefix}/{model_name}"
@@ -146,14 +151,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found"
}
- model_string = _build_model_string(
- cfg.get("provider", ""),
- cfg["model_name"],
- cfg.get("custom_provider"),
+ provider_prefix = _resolve_provider_prefix(
+ cfg.get("provider", ""), cfg.get("custom_provider")
)
+ model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
- if cfg.get("api_base"):
- gen_kwargs["api_base"] = cfg["api_base"]
+ api_base = resolve_api_base(
+ provider=cfg.get("provider"),
+ provider_prefix=provider_prefix,
+ config_api_base=cfg.get("api_base"),
+ )
+ if api_base:
+ gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
@@ -175,14 +184,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found"
}
- model_string = _build_model_string(
- db_cfg.provider.value,
- db_cfg.model_name,
- db_cfg.custom_provider,
+ provider_prefix = _resolve_provider_prefix(
+ db_cfg.provider.value, db_cfg.custom_provider
)
+ model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
- if db_cfg.api_base:
- gen_kwargs["api_base"] = db_cfg.api_base
+ api_base = resolve_api_base(
+ provider=db_cfg.provider.value,
+ provider_prefix=provider_prefix,
+ config_api_base=db_cfg.api_base,
+ )
+ if api_base:
+ gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
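
The refactor above extracts prefix resolution so the model string and the API base are derived from the same value instead of being computed twice. Given the definition of _resolve_provider_prefix earlier in this hunk, its behavior looks like this (the OPENAI entry is assumed; _PROVIDER_MAP itself is truncated here):

    _resolve_provider_prefix("OPENAI", None)         # _PROVIDER_MAP["OPENAI"]
    _resolve_provider_prefix("SomeVendor", None)     # "somevendor" (lowercase fallback)
    _resolve_provider_prefix("OPENAI", "my_proxy")   # "my_proxy" (custom provider wins)
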
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py b/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py
new file mode 100644
index 000000000..0ca1191a4
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py
@@ -0,0 +1,41 @@
+from typing import Any
+
+from app.db import SearchSourceConnector
+from app.services.composio_service import ComposioService
+
+
+def split_recipients(value: str | None) -> list[str]:
+ if not value:
+ return []
+ return [recipient.strip() for recipient in value.split(",") if recipient.strip()]
+
+
+def unwrap_composio_data(data: Any) -> Any:
+ if isinstance(data, dict):
+ inner = data.get("data", data)
+ if isinstance(inner, dict):
+ return inner.get("response_data", inner)
+ return inner
+ return data
+
+
+async def execute_composio_gmail_tool(
+ connector: SearchSourceConnector,
+ user_id: str,
+ tool_name: str,
+ params: dict[str, Any],
+) -> tuple[Any, str | None]:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return None, "Composio connected account ID not found for this Gmail connector."
+
+ result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name=tool_name,
+ params=params,
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown Composio Gmail error")
+
+ return unwrap_composio_data(result.get("data")), None
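
The two pure helpers in this new module are easiest to read through examples (values illustrative):

    split_recipients("a@example.com, b@example.com, ")  # ["a@example.com", "b@example.com"]
    split_recipients(None)                              # []

    # unwrap_composio_data peels the nested envelope it expects from
    # Composio responses, falling through at each missing layer:
    unwrap_composio_data({"data": {"response_data": {"id": "r1"}}})  # {"id": "r1"}
    unwrap_composio_data({"data": {"id": "r1"}})                     # {"id": "r1"}
    unwrap_composio_data({"data": [1, 2]})                           # [1, 2]
    unwrap_composio_data("raw")                                      # "raw"
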
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py
index 0bd044695..c88b48d2d 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_create_gmail_draft_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_gmail_draft tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the compiled-agent cache can safely
+ share the closure across HTTP requests. Capturing a per-request
+ session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured create_gmail_draft tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_gmail_draft(
to: str,
@@ -57,246 +75,276 @@ def create_create_gmail_draft_tool(
"""
logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
- metadata_service = GmailToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
-
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {"status": "error", "message": context["error"]}
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- logger.warning("All Gmail accounts have expired authentication")
- return {
- "status": "auth_error",
- "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "gmail",
- }
-
- logger.info(
- f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
- )
- result = request_approval(
- action_type="gmail_draft_creation",
- tool_name="create_gmail_draft",
- params={
- "to": to,
- "subject": subject,
- "body": body,
- "cc": cc,
- "bcc": bcc,
- "connector_id": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
- }
-
- final_to = result.params.get("to", to)
- final_subject = result.params.get("subject", subject)
- final_body = result.params.get("body", body)
- final_cc = result.params.get("cc", cc)
- final_bcc = result.params.get("bcc", bcc)
- final_connector_id = result.params.get("connector_id")
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _gmail_types = [
- SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
- ]
-
- if final_connector_id is not None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
- )
+ async with async_session_maker() as db_session:
+ metadata_service = GmailToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
)
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Gmail connector is invalid or has been disconnected.",
- }
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
+
+ if "error" in context:
+ logger.error(
+ f"Failed to fetch creation context: {context['error']}"
)
- )
- connector = result.scalars().first()
- if not connector:
+ return {"status": "error", "message": context["error"]}
+
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ logger.warning("All Gmail accounts have expired authentication")
return {
- "status": "error",
- "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
+ "status": "auth_error",
+ "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "gmail",
}
- actual_connector_id = connector.id
- logger.info(
- f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
- )
+ logger.info(
+ f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
+ )
+ result = request_approval(
+ action_type="gmail_draft_creation",
+ tool_name="create_gmail_draft",
+ params={
+ "to": to,
+ "subject": subject,
+ "body": body,
+ "cc": cc,
+ "bcc": bcc,
+ "connector_id": None,
+ },
+ context=context,
+ )
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
+ }
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
+ final_to = result.params.get("to", to)
+ final_subject = result.params.get("subject", subject)
+ final_body = result.params.get("body", body)
+ final_cc = result.params.get("cc", cc)
+ final_bcc = result.params.get("bcc", bcc)
+ final_connector_id = result.params.get("connector_id")
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _gmail_types = [
+ SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
+ ]
+
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Gmail connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = connector.id
else:
- return {
- "status": "error",
- "message": "Composio connected account ID not found for this Gmail connector.",
- }
- else:
- from google.oauth2.credentials import Credentials
-
- from app.config import config
- from app.utils.oauth_security import TokenEncryption
-
- config_data = dict(connector.config)
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and config.SECRET_KEY:
- token_encryption = TokenEncryption(config.SECRET_KEY)
- if config_data.get("token"):
- config_data["token"] = token_encryption.decrypt_token(
- config_data["token"]
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
)
- if config_data.get("refresh_token"):
- config_data["refresh_token"] = token_encryption.decrypt_token(
- config_data["refresh_token"]
- )
- if config_data.get("client_secret"):
- config_data["client_secret"] = token_encryption.decrypt_token(
- config_data["client_secret"]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
- )
-
- from googleapiclient.discovery import build
-
- gmail_service = build("gmail", "v1", credentials=creds)
-
- message = MIMEText(final_body)
- message["to"] = final_to
- message["subject"] = final_subject
- if final_cc:
- message["cc"] = final_cc
- if final_bcc:
- message["bcc"] = final_bcc
- raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
-
- try:
- created = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .drafts()
- .create(userId="me", body={"message": {"raw": raw}})
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
- try:
- from sqlalchemy.orm.attributes import flag_modified
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
+ }
+ actual_connector_id = connector.id
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
+ logger.info(
+ f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
+ )
+
+ is_composio_gmail = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ )
+ if is_composio_gmail:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Gmail connector.",
+ }
+ else:
+ from google.oauth2.credentials import Credentials
+
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
+
+ config_data = dict(connector.config)
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
)
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = (
+ token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = (
+ token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
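+                # Build the raw MIME payload used by the native Gmail API path;
+                # the Composio path sends structured fields instead.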
+ message = MIMEText(final_body)
+ message["to"] = final_to
+ message["subject"] = final_subject
+ if final_cc:
+ message["cc"] = final_cc
+ if final_bcc:
+ message["bcc"] = final_bcc
+ raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
+
+ try:
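+                # Two execution paths: Composio accounts invoke the
+                # GMAIL_CREATE_EMAIL_DRAFT action; native OAuth accounts call
+                # the Gmail API in a worker thread.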
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ split_recipients,
)
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
+
+ created, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_CREATE_EMAIL_DRAFT",
+ {
+ "user_id": "me",
+ "recipient_email": final_to,
+ "subject": final_subject,
+ "body": final_body,
+ "cc": split_recipients(final_cc),
+ "bcc": split_recipients(final_bcc),
+ "is_html": False,
+ },
+ )
+ if error:
+ raise RuntimeError(error)
+ if not isinstance(created, dict):
+ created = {}
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ created = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .drafts()
+ .create(userId="me", body={"message": {"raw": raw}})
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
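+                # A 403 from the Gmail API is treated as missing scopes: mark the
+                # connector auth_expired so the UI can prompt re-authentication.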
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
+ f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
- return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
- }
- raise
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
- logger.info(f"Gmail draft created: id={created.get('id')}")
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
- kb_message_suffix = ""
- try:
- from app.services.gmail import GmailKBSyncService
+ logger.info(f"Gmail draft created: id={created.get('id')}")
- kb_service = GmailKBSyncService(db_session)
- draft_message = created.get("message", {})
- kb_result = await kb_service.sync_after_create(
- message_id=draft_message.get("id", ""),
- thread_id=draft_message.get("threadId", ""),
- subject=final_subject,
- sender="me",
- date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
- body_text=final_body,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
- draft_id=created.get("id"),
- )
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
- else:
+ kb_message_suffix = ""
+ try:
+ from app.services.gmail import GmailKBSyncService
+
+ kb_service = GmailKBSyncService(db_session)
+ draft_message = created.get("message", {})
+ kb_result = await kb_service.sync_after_create(
+ message_id=draft_message.get("id", ""),
+ thread_id=draft_message.get("threadId", ""),
+ subject=final_subject,
+ sender="me",
+ date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+ body_text=final_body,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ draft_id=created.get("id"),
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
- return {
- "status": "success",
- "draft_id": created.get("id"),
- "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
- }
+ return {
+ "status": "success",
+ "draft_id": created.get("id"),
+ "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
index deec1627c..464713591 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
@@ -5,7 +5,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@@ -20,6 +20,23 @@ def create_read_gmail_email_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the read_gmail_email tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured read_gmail_email tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def read_gmail_email(message_id: str) -> dict[str, Any]:
"""Read the full content of a specific Gmail email by its message ID.
@@ -32,60 +49,115 @@ def create_read_gmail_email_tool(
Returns:
Dictionary with status and the full email content formatted as markdown.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."}
try:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
+ )
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
- }
-
- from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
-
- creds = _build_credentials(connector)
-
- from app.connectors.google_gmail_connector import GoogleGmailConnector
-
- gmail = GoogleGmailConnector(
- credentials=creds,
- session=db_session,
- user_id=user_id,
- connector_id=connector.id,
- )
-
- detail, error = await gmail.get_message_details(message_id)
- if error:
- if (
- "re-authenticate" in error.lower()
- or "authentication failed" in error.lower()
- ):
+ connector = result.scalars().first()
+ if not connector:
return {
- "status": "auth_error",
- "message": error,
- "connector_type": "gmail",
+ "status": "error",
+ "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
- return {"status": "error", "message": error}
- if not detail:
+ if (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ ):
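+                # Composio path: fetch the message through the Composio service
+                # and render it with the shared summary helper.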
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found.",
+ }
+
+ from app.agents.new_chat.tools.gmail.search_emails import (
+ _format_gmail_summary,
+ )
+ from app.services.composio_service import ComposioService
+
+ service = ComposioService()
+ detail, error = await service.get_gmail_message_detail(
+ connected_account_id=cca_id,
+ entity_id=f"surfsense_{user_id}",
+ message_id=message_id,
+ )
+ if error:
+ return {"status": "error", "message": error}
+ if not detail:
+ return {
+ "status": "not_found",
+ "message": f"Email with ID '{message_id}' not found.",
+ }
+
+ summary = _format_gmail_summary(detail)
+ content = (
+ f"# {summary['subject']}\n\n"
+ f"**From:** {summary['from']}\n"
+ f"**To:** {summary['to']}\n"
+ f"**Date:** {summary['date']}\n\n"
+ f"## Message Content\n\n"
+ f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
+ f"## Message Details\n\n"
+ f"- **Message ID:** {summary['message_id']}\n"
+ f"- **Thread ID:** {summary['thread_id']}\n"
+ )
+ return {
+ "status": "success",
+ "message_id": summary["message_id"] or message_id,
+ "content": content,
+ }
+
+ from app.agents.new_chat.tools.gmail.search_emails import (
+ _build_credentials,
+ )
+
+ creds = _build_credentials(connector)
+
+ from app.connectors.google_gmail_connector import GoogleGmailConnector
+
+ gmail = GoogleGmailConnector(
+ credentials=creds,
+ session=db_session,
+ user_id=user_id,
+ connector_id=connector.id,
+ )
+
+ detail, error = await gmail.get_message_details(message_id)
+ if error:
+ if (
+ "re-authenticate" in error.lower()
+ or "authentication failed" in error.lower()
+ ):
+ return {
+ "status": "auth_error",
+ "message": error,
+ "connector_type": "gmail",
+ }
+ return {"status": "error", "message": error}
+
+ if not detail:
+ return {
+ "status": "not_found",
+ "message": f"Email with ID '{message_id}' not found.",
+ }
+
+ content = gmail.format_message_to_markdown(detail)
+
return {
- "status": "not_found",
- "message": f"Email with ID '{message_id}' not found.",
+ "status": "success",
+ "message_id": message_id,
+ "content": content,
}
- content = gmail.format_message_to_markdown(detail)
-
- return {"status": "success", "message_id": message_id, "content": content}
-
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
index 2e363609e..3ce154c53 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
@@ -6,7 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector):
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
- from app.utils.google_credentials import build_composio_credentials
-
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
- raise ValueError("Composio connected account ID not found.")
- return build_composio_credentials(cca_id)
+ raise ValueError("Composio connectors must use Composio tool execution.")
from google.oauth2.credentials import Credentials
@@ -67,11 +62,85 @@ def _build_credentials(connector: SearchSourceConnector):
)
+def _gmail_headers(message: dict[str, Any]) -> dict[str, str]:
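+    """Map a Gmail message's payload headers to a lower-cased name/value dict."""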
+ headers = message.get("payload", {}).get("headers", [])
+ return {
+ header.get("name", "").lower(): header.get("value", "")
+ for header in headers
+ if isinstance(header, dict)
+ }
+
+
+def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]:
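+    """Normalize a message (Composio or raw Gmail API shape) into a summary dict."""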
+ headers = _gmail_headers(message)
+ return {
+ "message_id": message.get("id") or message.get("messageId"),
+ "thread_id": message.get("threadId"),
+ "subject": message.get("subject") or headers.get("subject", "No Subject"),
+ "from": message.get("sender") or headers.get("from", "Unknown"),
+ "to": message.get("to") or headers.get("to", ""),
+ "date": message.get("messageTimestamp") or headers.get("date", ""),
+ "snippet": message.get("snippet") or message.get("messageText", "")[:300],
+ "labels": message.get("labelIds", []),
+ }
+
+
+async def _search_composio_gmail(
+ connector: SearchSourceConnector,
+ user_id: str,
+ query: str,
+ max_results: int,
+) -> dict[str, Any]:
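+    """Search Gmail via Composio and return summaries in the tool's result shape."""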
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found.",
+ }
+
+ from app.services.composio_service import ComposioService
+
+ service = ComposioService()
+ messages, _next_token, _estimate, error = await service.get_gmail_messages(
+ connected_account_id=cca_id,
+ entity_id=f"surfsense_{user_id}",
+ query=query,
+ max_results=max_results,
+ )
+ if error:
+ return {"status": "error", "message": error}
+
+ emails = [_format_gmail_summary(message) for message in messages]
+ return {
+ "status": "success",
+ "emails": emails,
+ "total": len(emails),
+ "message": "No emails found." if not emails else None,
+ }
+
+
def create_search_gmail_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the search_gmail tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured search_gmail tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def search_gmail(
query: str,
@@ -90,83 +159,92 @@ def create_search_gmail_tool(
Dictionary with status and a list of email summaries including
message_id, subject, from, date, snippet.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."}
max_results = min(max_results, 20)
try:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
+ )
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
- }
-
- creds = _build_credentials(connector)
-
- from app.connectors.google_gmail_connector import GoogleGmailConnector
-
- gmail = GoogleGmailConnector(
- credentials=creds,
- session=db_session,
- user_id=user_id,
- connector_id=connector.id,
- )
-
- messages_list, error = await gmail.get_messages_list(
- max_results=max_results, query=query
- )
- if error:
- if (
- "re-authenticate" in error.lower()
- or "authentication failed" in error.lower()
- ):
+ connector = result.scalars().first()
+ if not connector:
return {
- "status": "auth_error",
- "message": error,
- "connector_type": "gmail",
+ "status": "error",
+ "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
- return {"status": "error", "message": error}
- if not messages_list:
- return {
- "status": "success",
- "emails": [],
- "total": 0,
- "message": "No emails found.",
- }
+ if (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ ):
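+                    # Composio-backed accounts are handled by the helper above.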
+ return await _search_composio_gmail(
+ connector, str(user_id), query, max_results
+ )
- emails = []
- for msg in messages_list:
- detail, err = await gmail.get_message_details(msg["id"])
- if err:
- continue
- headers = {
- h["name"].lower(): h["value"]
- for h in detail.get("payload", {}).get("headers", [])
- }
- emails.append(
- {
- "message_id": detail.get("id"),
- "thread_id": detail.get("threadId"),
- "subject": headers.get("subject", "No Subject"),
- "from": headers.get("from", "Unknown"),
- "to": headers.get("to", ""),
- "date": headers.get("date", ""),
- "snippet": detail.get("snippet", ""),
- "labels": detail.get("labelIds", []),
- }
+ creds = _build_credentials(connector)
+
+ from app.connectors.google_gmail_connector import GoogleGmailConnector
+
+ gmail = GoogleGmailConnector(
+ credentials=creds,
+ session=db_session,
+ user_id=user_id,
+ connector_id=connector.id,
)
- return {"status": "success", "emails": emails, "total": len(emails)}
+ messages_list, error = await gmail.get_messages_list(
+ max_results=max_results, query=query
+ )
+ if error:
+ if (
+ "re-authenticate" in error.lower()
+ or "authentication failed" in error.lower()
+ ):
+ return {
+ "status": "auth_error",
+ "message": error,
+ "connector_type": "gmail",
+ }
+ return {"status": "error", "message": error}
+
+ if not messages_list:
+ return {
+ "status": "success",
+ "emails": [],
+ "total": 0,
+ "message": "No emails found.",
+ }
+
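+                # Fetch per-message details and flatten headers into summaries.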
+ emails = []
+ for msg in messages_list:
+ detail, err = await gmail.get_message_details(msg["id"])
+ if err:
+ continue
+ headers = {
+ h["name"].lower(): h["value"]
+ for h in detail.get("payload", {}).get("headers", [])
+ }
+ emails.append(
+ {
+ "message_id": detail.get("id"),
+ "thread_id": detail.get("threadId"),
+ "subject": headers.get("subject", "No Subject"),
+ "from": headers.get("from", "Unknown"),
+ "to": headers.get("to", ""),
+ "date": headers.get("date", ""),
+ "snippet": detail.get("snippet", ""),
+ "labels": detail.get("labelIds", []),
+ }
+ )
+
+ return {"status": "success", "emails": emails, "total": len(emails)}
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
index c3f0999f4..4d5aa3bcc 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_send_gmail_email_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the send_gmail_email tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured send_gmail_email tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def send_gmail_email(
to: str,
@@ -58,247 +76,277 @@ def create_send_gmail_email_tool(
"""
logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
- metadata_service = GmailToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
-
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {"status": "error", "message": context["error"]}
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- logger.warning("All Gmail accounts have expired authentication")
- return {
- "status": "auth_error",
- "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "gmail",
- }
-
- logger.info(
- f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
- )
- result = request_approval(
- action_type="gmail_email_send",
- tool_name="send_gmail_email",
- params={
- "to": to,
- "subject": subject,
- "body": body,
- "cc": cc,
- "bcc": bcc,
- "connector_id": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
- }
-
- final_to = result.params.get("to", to)
- final_subject = result.params.get("subject", subject)
- final_body = result.params.get("body", body)
- final_cc = result.params.get("cc", cc)
- final_bcc = result.params.get("bcc", bcc)
- final_connector_id = result.params.get("connector_id")
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _gmail_types = [
- SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
- ]
-
- if final_connector_id is not None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
- )
+ async with async_session_maker() as db_session:
+ metadata_service = GmailToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
)
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Gmail connector is invalid or has been disconnected.",
- }
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
+
+ if "error" in context:
+ logger.error(
+ f"Failed to fetch creation context: {context['error']}"
)
- )
- connector = result.scalars().first()
- if not connector:
+ return {"status": "error", "message": context["error"]}
+
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ logger.warning("All Gmail accounts have expired authentication")
return {
- "status": "error",
- "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
+ "status": "auth_error",
+ "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "gmail",
}
- actual_connector_id = connector.id
- logger.info(
- f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
- )
+ logger.info(
+ f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
+ )
+ result = request_approval(
+ action_type="gmail_email_send",
+ tool_name="send_gmail_email",
+ params={
+ "to": to,
+ "subject": subject,
+ "body": body,
+ "cc": cc,
+ "bcc": bcc,
+ "connector_id": None,
+ },
+ context=context,
+ )
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
+ }
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
+ final_to = result.params.get("to", to)
+ final_subject = result.params.get("subject", subject)
+ final_body = result.params.get("body", body)
+ final_cc = result.params.get("cc", cc)
+ final_bcc = result.params.get("bcc", bcc)
+ final_connector_id = result.params.get("connector_id")
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _gmail_types = [
+ SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
+ ]
+
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Gmail connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = connector.id
else:
- return {
- "status": "error",
- "message": "Composio connected account ID not found for this Gmail connector.",
- }
- else:
- from google.oauth2.credentials import Credentials
-
- from app.config import config
- from app.utils.oauth_security import TokenEncryption
-
- config_data = dict(connector.config)
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and config.SECRET_KEY:
- token_encryption = TokenEncryption(config.SECRET_KEY)
- if config_data.get("token"):
- config_data["token"] = token_encryption.decrypt_token(
- config_data["token"]
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
)
- if config_data.get("refresh_token"):
- config_data["refresh_token"] = token_encryption.decrypt_token(
- config_data["refresh_token"]
- )
- if config_data.get("client_secret"):
- config_data["client_secret"] = token_encryption.decrypt_token(
- config_data["client_secret"]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
- )
-
- from googleapiclient.discovery import build
-
- gmail_service = build("gmail", "v1", credentials=creds)
-
- message = MIMEText(final_body)
- message["to"] = final_to
- message["subject"] = final_subject
- if final_cc:
- message["cc"] = final_cc
- if final_bcc:
- message["bcc"] = final_bcc
- raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
-
- try:
- sent = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .messages()
- .send(userId="me", body={"raw": raw})
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
- try:
- from sqlalchemy.orm.attributes import flag_modified
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
+ }
+ actual_connector_id = connector.id
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
+ logger.info(
+ f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
+ )
+
+ is_composio_gmail = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ )
+ if is_composio_gmail:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Gmail connector.",
+ }
+ else:
+ from google.oauth2.credentials import Credentials
+
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
+
+ config_data = dict(connector.config)
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
)
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = (
+ token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = (
+ token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
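+                # Build the raw MIME payload for the native Gmail API path.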
+ message = MIMEText(final_body)
+ message["to"] = final_to
+ message["subject"] = final_subject
+ if final_cc:
+ message["cc"] = final_cc
+ if final_bcc:
+ message["bcc"] = final_bcc
+ raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
+
+ try:
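+                # Composio accounts send via the GMAIL_SEND_EMAIL action; native
+                # OAuth accounts send the raw MIME message through the Gmail API.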
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ split_recipients,
)
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
+
+ sent, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_SEND_EMAIL",
+ {
+ "user_id": "me",
+ "recipient_email": final_to,
+ "subject": final_subject,
+ "body": final_body,
+ "cc": split_recipients(final_cc),
+ "bcc": split_recipients(final_bcc),
+ "is_html": False,
+ },
+ )
+ if error:
+ raise RuntimeError(error)
+ if not isinstance(sent, dict):
+ sent = {}
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ sent = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .messages()
+ .send(userId="me", body={"raw": raw})
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
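+                # As in the draft tool, a 403 marks the connector auth_expired so
+                # the user is prompted to re-authenticate.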
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
+ f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
- return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
- }
- raise
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
- logger.info(
- f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
- )
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
- kb_message_suffix = ""
- try:
- from app.services.gmail import GmailKBSyncService
-
- kb_service = GmailKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- message_id=sent.get("id", ""),
- thread_id=sent.get("threadId", ""),
- subject=final_subject,
- sender="me",
- date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
- body_text=final_body,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
+ logger.info(
+ f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
)
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
- else:
- kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after send failed: {kb_err}")
- kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
- return {
- "status": "success",
- "message_id": sent.get("id"),
- "thread_id": sent.get("threadId"),
- "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
- }
+ kb_message_suffix = ""
+ try:
+ from app.services.gmail import GmailKBSyncService
+
+ kb_service = GmailKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ message_id=sent.get("id", ""),
+ thread_id=sent.get("threadId", ""),
+ subject=final_subject,
+ sender="me",
+ date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+ body_text=final_body,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after send failed: {kb_err}")
+ kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "message_id": sent.get("id"),
+ "thread_id": sent.get("threadId"),
+ "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
index 1f1f6227a..95f5b4e6c 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
@@ -7,6 +7,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@@ -17,6 +18,23 @@ def create_trash_gmail_email_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the trash_gmail_email tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured trash_gmail_email tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def trash_gmail_email(
email_subject_or_id: str,
@@ -55,244 +73,261 @@ def create_trash_gmail_email_tool(
f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
- metadata_service = GmailToolMetadataService(db_session)
- context = await metadata_service.get_trash_context(
- search_space_id, user_id, email_subject_or_id
- )
-
- if "error" in context:
- error_msg = context["error"]
- if "not found" in error_msg.lower():
- logger.warning(f"Email not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- logger.error(f"Failed to fetch trash context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Gmail account %s has expired authentication",
- account.get("id"),
+ async with async_session_maker() as db_session:
+ metadata_service = GmailToolMetadataService(db_session)
+ context = await metadata_service.get_trash_context(
+ search_space_id, user_id, email_subject_or_id
)
- return {
- "status": "auth_error",
- "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "gmail",
- }
- email = context["email"]
- message_id = email["message_id"]
- document_id = email.get("document_id")
- connector_id_from_context = context["account"]["id"]
+ if "error" in context:
+ error_msg = context["error"]
+ if "not found" in error_msg.lower():
+ logger.warning(f"Email not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ logger.error(f"Failed to fetch trash context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- if not message_id:
- return {
- "status": "error",
- "message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
- }
+ account = context.get("account", {})
+ if account.get("auth_expired"):
+ logger.warning(
+ "Gmail account %s has expired authentication",
+ account.get("id"),
+ )
+ return {
+ "status": "auth_error",
+ "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "gmail",
+ }
- logger.info(
- f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
- )
- result = request_approval(
- action_type="gmail_email_trash",
- tool_name="trash_gmail_email",
- params={
- "message_id": message_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
+ email = context["email"]
+ message_id = email["message_id"]
+ document_id = email.get("document_id")
+ connector_id_from_context = context["account"]["id"]
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
- }
-
- final_message_id = result.params.get("message_id", message_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this email.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _gmail_types = [
- SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
- ]
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
- )
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Gmail connector is invalid or has been disconnected.",
- }
-
- logger.info(
- f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
- )
-
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
+ if not message_id:
return {
"status": "error",
- "message": "Composio connected account ID not found for this Gmail connector.",
+ "message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
}
- else:
- from google.oauth2.credentials import Credentials
- from app.config import config
- from app.utils.oauth_security import TokenEncryption
-
- config_data = dict(connector.config)
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and config.SECRET_KEY:
- token_encryption = TokenEncryption(config.SECRET_KEY)
- if config_data.get("token"):
- config_data["token"] = token_encryption.decrypt_token(
- config_data["token"]
- )
- if config_data.get("refresh_token"):
- config_data["refresh_token"] = token_encryption.decrypt_token(
- config_data["refresh_token"]
- )
- if config_data.get("client_secret"):
- config_data["client_secret"] = token_encryption.decrypt_token(
- config_data["client_secret"]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ logger.info(
+ f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
+ )
+ result = request_approval(
+ action_type="gmail_email_trash",
+ tool_name="trash_gmail_email",
+ params={
+ "message_id": message_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- from googleapiclient.discovery import build
-
- gmail_service = build("gmail", "v1", credentials=creds)
-
- try:
- await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .messages()
- .trash(userId="me", id=final_message_id)
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {connector.id}: {api_err}"
- )
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- if not connector.config.get("auth_expired"):
- connector.config = {
- **connector.config,
- "auth_expired": True,
- }
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- connector.id,
- exc_info=True,
- )
+ if result.rejected:
return {
- "status": "insufficient_permissions",
- "connector_id": connector.id,
- "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "rejected",
+ "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
}
- raise
- logger.info(f"Gmail email trashed: message_id={final_message_id}")
-
- trash_result: dict[str, Any] = {
- "status": "success",
- "message_id": final_message_id,
- "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
- }
-
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- from app.db import Document
-
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- trash_result["warning"] = (
- f"Email trashed, but failed to remove from knowledge base: {e!s}"
- )
-
- trash_result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- trash_result["message"] = (
- f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ final_message_id = result.params.get("message_id", message_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
)
- return trash_result
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this email.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _gmail_types = [
+ SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
+ ]
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Gmail connector is invalid or has been disconnected.",
+ }
+
+ logger.info(
+ f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
+ )
+
+ is_composio_gmail = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ )
+ if is_composio_gmail:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Gmail connector.",
+ }
+ else:
+ from google.oauth2.credentials import Credentials
+
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
+
+ config_data = dict(connector.config)
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = (
+ token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = (
+ token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
+ try:
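+                    # Composio accounts use the GMAIL_MOVE_TO_TRASH action;
+                    # native accounts call messages.trash on the Gmail API.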
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ )
+
+ _trashed, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_MOVE_TO_TRASH",
+ {"user_id": "me", "message_id": final_message_id},
+ )
+ if error:
+ raise RuntimeError(error)
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .messages()
+ .trash(userId="me", id=final_message_id)
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {connector.id}: {api_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ if not connector.config.get("auth_expired"):
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ connector.id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": connector.id,
+ "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(f"Gmail email trashed: message_id={final_message_id}")
+
+ trash_result: dict[str, Any] = {
+ "status": "success",
+ "message_id": final_message_id,
+ "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
+ }
+
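+                # Optionally remove the indexed copy from the knowledge base when
+                # the user approved delete_from_kb.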
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ trash_result["warning"] = (
+ f"Email trashed, but failed to remove from knowledge base: {e!s}"
+ )
+
+ trash_result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ trash_result["message"] = (
+ f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py
index 91178cd21..129b7defb 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_update_gmail_draft_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the update_gmail_draft tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured update_gmail_draft tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def update_gmail_draft(
draft_subject_or_id: str,
@@ -76,294 +94,329 @@ def create_update_gmail_draft_tool(
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
- metadata_service = GmailToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, draft_subject_or_id
- )
-
- if "error" in context:
- error_msg = context["error"]
- if "not found" in error_msg.lower():
- logger.warning(f"Draft not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- logger.error(f"Failed to fetch update context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Gmail account %s has expired authentication",
- account.get("id"),
- )
- return {
- "status": "auth_error",
- "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "gmail",
- }
-
- email = context["email"]
- message_id = email["message_id"]
- document_id = email.get("document_id")
- connector_id_from_context = account["id"]
- draft_id_from_context = context.get("draft_id")
-
- original_subject = email.get("subject", draft_subject_or_id)
- final_subject_default = subject if subject else original_subject
- final_to_default = to if to else ""
-
- logger.info(
- f"Requesting approval for updating Gmail draft: '{original_subject}' "
- f"(message_id={message_id}, draft_id={draft_id_from_context})"
- )
- result = request_approval(
- action_type="gmail_draft_update",
- tool_name="update_gmail_draft",
- params={
- "message_id": message_id,
- "draft_id": draft_id_from_context,
- "to": final_to_default,
- "subject": final_subject_default,
- "body": body,
- "cc": cc,
- "bcc": bcc,
- "connector_id": connector_id_from_context,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
- }
-
- final_to = result.params.get("to", final_to_default)
- final_subject = result.params.get("subject", final_subject_default)
- final_body = result.params.get("body", body)
- final_cc = result.params.get("cc", cc)
- final_bcc = result.params.get("bcc", bcc)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_draft_id = result.params.get("draft_id", draft_id_from_context)
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this draft.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _gmail_types = [
- SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
- ]
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
- )
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Gmail connector is invalid or has been disconnected.",
- }
-
- logger.info(
- f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
- )
-
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
- return {
- "status": "error",
- "message": "Composio connected account ID not found for this Gmail connector.",
- }
- else:
- from google.oauth2.credentials import Credentials
-
- from app.config import config
- from app.utils.oauth_security import TokenEncryption
-
- config_data = dict(connector.config)
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and config.SECRET_KEY:
- token_encryption = TokenEncryption(config.SECRET_KEY)
- if config_data.get("token"):
- config_data["token"] = token_encryption.decrypt_token(
- config_data["token"]
- )
- if config_data.get("refresh_token"):
- config_data["refresh_token"] = token_encryption.decrypt_token(
- config_data["refresh_token"]
- )
- if config_data.get("client_secret"):
- config_data["client_secret"] = token_encryption.decrypt_token(
- config_data["client_secret"]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ async with async_session_maker() as db_session:
+ metadata_service = GmailToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, draft_subject_or_id
)
- from googleapiclient.discovery import build
+ if "error" in context:
+ error_msg = context["error"]
+ if "not found" in error_msg.lower():
+ logger.warning(f"Draft not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ logger.error(f"Failed to fetch update context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- gmail_service = build("gmail", "v1", credentials=creds)
-
- # Resolve draft_id if not already available
- if not final_draft_id:
- logger.info(
- f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
- )
- final_draft_id = await _find_draft_id_by_message(
- gmail_service, message_id
- )
-
- if not final_draft_id:
- return {
- "status": "error",
- "message": (
- "Could not find this draft in Gmail. "
- "It may have already been sent or deleted."
- ),
- }
-
- message = MIMEText(final_body)
- if final_to:
- message["to"] = final_to
- message["subject"] = final_subject
- if final_cc:
- message["cc"] = final_cc
- if final_bcc:
- message["bcc"] = final_bcc
- raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
-
- try:
- updated = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .drafts()
- .update(
- userId="me",
- id=final_draft_id,
- body={"message": {"raw": raw}},
- )
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+ account = context.get("account", {})
+ if account.get("auth_expired"):
logger.warning(
- f"Insufficient permissions for connector {connector.id}: {api_err}"
+ "Gmail account %s has expired authentication",
+ account.get("id"),
)
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- if not connector.config.get("auth_expired"):
- connector.config = {
- **connector.config,
- "auth_expired": True,
- }
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- connector.id,
- exc_info=True,
- )
return {
- "status": "insufficient_permissions",
- "connector_id": connector.id,
- "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "auth_error",
+ "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "gmail",
}
- if isinstance(api_err, HttpError) and api_err.resp.status == 404:
+
+ email = context["email"]
+ message_id = email["message_id"]
+ document_id = email.get("document_id")
+ connector_id_from_context = account["id"]
+ draft_id_from_context = context.get("draft_id")
+
+ original_subject = email.get("subject", draft_subject_or_id)
+ final_subject_default = subject if subject else original_subject
+ final_to_default = to if to else ""
+
+ logger.info(
+ f"Requesting approval for updating Gmail draft: '{original_subject}' "
+ f"(message_id={message_id}, draft_id={draft_id_from_context})"
+ )
+ result = request_approval(
+ action_type="gmail_draft_update",
+ tool_name="update_gmail_draft",
+ params={
+ "message_id": message_id,
+ "draft_id": draft_id_from_context,
+ "to": final_to_default,
+ "subject": final_subject_default,
+ "body": body,
+ "cc": cc,
+ "bcc": bcc,
+ "connector_id": connector_id_from_context,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
+ }
+
+ final_to = result.params.get("to", final_to_default)
+ final_subject = result.params.get("subject", final_subject_default)
+ final_body = result.params.get("body", body)
+ final_cc = result.params.get("cc", cc)
+ final_bcc = result.params.get("bcc", bcc)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_draft_id = result.params.get("draft_id", draft_id_from_context)
+
+ if not final_connector_id:
return {
"status": "error",
- "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
+ "message": "No connector found for this draft.",
}
- raise
- logger.info(f"Gmail draft updated: id={updated.get('id')}")
+ from sqlalchemy.future import select
- kb_message_suffix = ""
- if document_id:
- try:
- from sqlalchemy.future import select as sa_select
- from sqlalchemy.orm.attributes import flag_modified
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
- from app.db import Document
+ _gmail_types = [
+ SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
+ ]
- doc_result = await db_session.execute(
- sa_select(Document).filter(Document.id == document_id)
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
)
- document = doc_result.scalars().first()
- if document:
- document.source_markdown = final_body
- document.title = final_subject
- meta = dict(document.document_metadata or {})
- meta["subject"] = final_subject
- meta["draft_id"] = updated.get("id", final_draft_id)
- updated_msg = updated.get("message", {})
- if updated_msg.get("id"):
- meta["message_id"] = updated_msg["id"]
- document.document_metadata = meta
- flag_modified(document, "document_metadata")
- await db_session.commit()
- kb_message_suffix = (
- " Your knowledge base has also been updated."
- )
- logger.info(
- f"KB document {document_id} updated for draft {final_draft_id}"
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Gmail connector is invalid or has been disconnected.",
+ }
+
+ logger.info(
+ f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
+ )
+
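+                # Composio-managed connectors only need the connected-account
+                # id here; native connectors rebuild Google OAuth credentials
+                # from the stored (possibly encrypted) config below.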
+ is_composio_gmail = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ )
+ if is_composio_gmail:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Gmail connector.",
+ }
+ else:
+ from google.oauth2.credentials import Credentials
+
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
+
+ config_data = dict(connector.config)
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = (
+ token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = (
+ token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
+ # Resolve draft_id if not already available
+ if not final_draft_id:
+ logger.info(
+ f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
+ )
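+                    # Composio connectors cannot use the Google drafts API, so
+                    # resolve the draft id by paging Composio's drafts list;
+                    # native connectors build a Gmail client on demand.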
+ if is_composio_gmail:
+ final_draft_id = await _find_composio_draft_id_by_message(
+ connector, user_id, message_id
)
else:
- kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB update after draft edit failed: {kb_err}")
- await db_session.rollback()
- kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
+ from googleapiclient.discovery import build
- return {
- "status": "success",
- "draft_id": updated.get("id"),
- "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
- }
+ gmail_service = build("gmail", "v1", credentials=creds)
+ final_draft_id = await _find_draft_id_by_message(
+ gmail_service, message_id
+ )
+
+ if not final_draft_id:
+ return {
+ "status": "error",
+ "message": (
+ "Could not find this draft in Gmail. "
+ "It may have already been sent or deleted."
+ ),
+ }
+
+ message = MIMEText(final_body)
+ if final_to:
+ message["to"] = final_to
+ message["subject"] = final_subject
+ if final_cc:
+ message["cc"] = final_cc
+ if final_bcc:
+ message["bcc"] = final_bcc
+ raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
+
+ try:
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ split_recipients,
+ )
+
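+                        # cc/bcc arrive as comma-separated strings;
+                        # split_recipients turns them into the list form the
+                        # Composio action expects.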
+ updated, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_UPDATE_DRAFT",
+ {
+ "user_id": "me",
+ "draft_id": final_draft_id,
+ "recipient_email": final_to,
+ "subject": final_subject,
+ "body": final_body,
+ "cc": split_recipients(final_cc),
+ "bcc": split_recipients(final_bcc),
+ "is_html": False,
+ },
+ )
+ if error:
+ raise RuntimeError(error)
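+                        # Normalize non-dict payloads so the .get() calls on
+                        # `updated` below are safe.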
+ if not isinstance(updated, dict):
+ updated = {}
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ updated = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .drafts()
+ .update(
+ userId="me",
+ id=final_draft_id,
+ body={"message": {"raw": raw}},
+ )
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {connector.id}: {api_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ if not connector.config.get("auth_expired"):
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ connector.id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": connector.id,
+ "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ if isinstance(api_err, HttpError) and api_err.resp.status == 404:
+ return {
+ "status": "error",
+ "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
+ }
+ raise
+
+ logger.info(f"Gmail draft updated: id={updated.get('id')}")
+
+ kb_message_suffix = ""
+ if document_id:
+ try:
+ from sqlalchemy.future import select as sa_select
+ from sqlalchemy.orm.attributes import flag_modified
+
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ sa_select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ document.source_markdown = final_body
+ document.title = final_subject
+ meta = dict(document.document_metadata or {})
+ meta["subject"] = final_subject
+ meta["draft_id"] = updated.get("id", final_draft_id)
+ updated_msg = updated.get("message", {})
+ if updated_msg.get("id"):
+ meta["message_id"] = updated_msg["id"]
+ document.document_metadata = meta
+ flag_modified(document, "document_metadata")
+ await db_session.commit()
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ logger.info(
+ f"KB document {document_id} updated for draft {final_draft_id}"
+ )
+ else:
+ kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB update after draft edit failed: {kb_err}")
+ await db_session.rollback()
+ kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "draft_id": updated.get("id"),
+ "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
@@ -408,3 +461,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str
except Exception as e:
logger.warning(f"Failed to look up draft by message_id: {e}")
return None
+
+
+async def _find_composio_draft_id_by_message(
+ connector: Any, user_id: str, message_id: str
+) -> str | None:
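+    """Resolve a Gmail draft ID from its message ID via Composio.
+
+    Pages through GMAIL_LIST_DRAFTS until a draft whose message matches
+    ``message_id`` is found. Returns ``None`` if the lookup fails or the
+    draft no longer exists.
+    """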
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ )
+
+ page_token = ""
+ while True:
+ params: dict[str, Any] = {
+ "user_id": "me",
+ "max_results": 100,
+ "verbose": False,
+ }
+ if page_token:
+ params["page_token"] = page_token
+
+ data, error = await execute_composio_gmail_tool(
+ connector, user_id, "GMAIL_LIST_DRAFTS", params
+ )
+ if error or not isinstance(data, dict):
+ return None
+
+ for draft in data.get("drafts", []):
+ if draft.get("message", {}).get("id") == message_id:
+ return draft.get("id")
+
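+        # Accept both camelCase and snake_case pagination keys from Composio.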
+ page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
+ if not page_token:
+ return None
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
index 37bcf083e..dec92cc8b 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_create_calendar_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_calendar_event tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured create_calendar_event tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_calendar_event(
summary: str,
@@ -60,254 +78,294 @@ def create_create_calendar_event_tool(
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
- metadata_service = GoogleCalendarToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
-
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {"status": "error", "message": context["error"]}
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- logger.warning(
- "All Google Calendar accounts have expired authentication"
+ async with async_session_maker() as db_session:
+ metadata_service = GoogleCalendarToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
)
- return {
- "status": "auth_error",
- "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "google_calendar",
- }
- logger.info(
- f"Requesting approval for creating calendar event: summary='{summary}'"
- )
- result = request_approval(
- action_type="google_calendar_event_creation",
- tool_name="create_calendar_event",
- params={
- "summary": summary,
- "start_datetime": start_datetime,
- "end_datetime": end_datetime,
- "description": description,
- "location": location,
- "attendees": attendees,
- "timezone": context.get("timezone"),
- "connector_id": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
- }
-
- final_summary = result.params.get("summary", summary)
- final_start_datetime = result.params.get("start_datetime", start_datetime)
- final_end_datetime = result.params.get("end_datetime", end_datetime)
- final_description = result.params.get("description", description)
- final_location = result.params.get("location", location)
- final_attendees = result.params.get("attendees", attendees)
- final_connector_id = result.params.get("connector_id")
-
- if not final_summary or not final_summary.strip():
- return {"status": "error", "message": "Event summary cannot be empty."}
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _calendar_types = [
- SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
- ]
-
- if final_connector_id is not None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_calendar_types),
+ if "error" in context:
+ logger.error(
+ f"Failed to fetch creation context: {context['error']}"
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Google Calendar connector is invalid or has been disconnected.",
- }
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_calendar_types),
+ return {"status": "error", "message": context["error"]}
+
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ logger.warning(
+ "All Google Calendar accounts have expired authentication"
)
+ return {
+ "status": "auth_error",
+ "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "google_calendar",
+ }
+
+ logger.info(
+ f"Requesting approval for creating calendar event: summary='{summary}'"
)
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
- }
- actual_connector_id = connector.id
-
- logger.info(
- f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
- )
-
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
- return {
- "status": "error",
- "message": "Composio connected account ID not found for this connector.",
- }
- else:
- config_data = dict(connector.config)
-
- from app.config import config as app_config
- from app.utils.oauth_security import TokenEncryption
-
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and app_config.SECRET_KEY:
- token_encryption = TokenEncryption(app_config.SECRET_KEY)
- for key in ("token", "refresh_token", "client_secret"):
- if config_data.get(key):
- config_data[key] = token_encryption.decrypt_token(
- config_data[key]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ result = request_approval(
+ action_type="google_calendar_event_creation",
+ tool_name="create_calendar_event",
+ params={
+ "summary": summary,
+ "start_datetime": start_datetime,
+ "end_datetime": end_datetime,
+ "description": description,
+ "location": location,
+ "attendees": attendees,
+ "timezone": context.get("timezone"),
+ "connector_id": None,
+ },
+ context=context,
)
- service = await asyncio.get_event_loop().run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
+ }
- tz = context.get("timezone", "UTC")
- event_body: dict[str, Any] = {
- "summary": final_summary,
- "start": {"dateTime": final_start_datetime, "timeZone": tz},
- "end": {"dateTime": final_end_datetime, "timeZone": tz},
- }
- if final_description:
- event_body["description"] = final_description
- if final_location:
- event_body["location"] = final_location
- if final_attendees:
- event_body["attendees"] = [
- {"email": e.strip()} for e in final_attendees if e.strip()
+ final_summary = result.params.get("summary", summary)
+ final_start_datetime = result.params.get(
+ "start_datetime", start_datetime
+ )
+ final_end_datetime = result.params.get("end_datetime", end_datetime)
+ final_description = result.params.get("description", description)
+ final_location = result.params.get("location", location)
+ final_attendees = result.params.get("attendees", attendees)
+ final_connector_id = result.params.get("connector_id")
+
+ if not final_summary or not final_summary.strip():
+ return {
+ "status": "error",
+ "message": "Event summary cannot be empty.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _calendar_types = [
+ SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
- try:
- created = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- service.events()
- .insert(calendarId="primary", body=event_body)
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_calendar_types),
+ )
)
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
- )
- )
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
- )
- return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
- }
- raise
-
- logger.info(
- f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
- )
-
- kb_message_suffix = ""
- try:
- from app.services.google_calendar import GoogleCalendarKBSyncService
-
- kb_service = GoogleCalendarKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- event_id=created.get("id"),
- event_summary=final_summary,
- calendar_id="primary",
- start_time=final_start_datetime,
- end_time=final_end_datetime,
- location=final_location,
- html_link=created.get("htmlLink"),
- description=final_description,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Google Calendar connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = connector.id
else:
- kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_calendar_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
+ }
+ actual_connector_id = connector.id
- return {
- "status": "success",
- "event_id": created.get("id"),
- "html_link": created.get("htmlLink"),
- "message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
- }
+ logger.info(
+ f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
+ )
+
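+                # Composio-managed connectors only need the connected-account
+                # id here; native connectors rebuild Google OAuth credentials
+                # from the stored (possibly encrypted) config below.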
+ is_composio_calendar = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+ )
+ if is_composio_calendar:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this connector.",
+ }
+ else:
+ config_data = dict(connector.config)
+
+ from app.config import config as app_config
+ from app.utils.oauth_security import TokenEncryption
+
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and app_config.SECRET_KEY:
+ token_encryption = TokenEncryption(app_config.SECRET_KEY)
+ for key in ("token", "refresh_token", "client_secret"):
+ if config_data.get(key):
+ config_data[key] = token_encryption.decrypt_token(
+ config_data[key]
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
+ tz = context.get("timezone", "UTC")
+ event_body: dict[str, Any] = {
+ "summary": final_summary,
+ "start": {"dateTime": final_start_datetime, "timeZone": tz},
+ "end": {"dateTime": final_end_datetime, "timeZone": tz},
+ }
+ if final_description:
+ event_body["description"] = final_description
+ if final_location:
+ event_body["location"] = final_location
+ if final_attendees:
+ event_body["attendees"] = [
+ {"email": e.strip()} for e in final_attendees if e.strip()
+ ]
+
+ try:
+ if is_composio_calendar:
+ from app.services.composio_service import ComposioService
+
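+                        # Map the approved event fields onto the
+                        # GOOGLECALENDAR_CREATE_EVENT arguments; optional
+                        # fields are attached only when set.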
+ composio_params = {
+ "calendar_id": "primary",
+ "summary": final_summary,
+ "start_datetime": final_start_datetime,
+ "end_datetime": final_end_datetime,
+ "timezone": tz,
+ "attendees": final_attendees or [],
+ }
+ if final_description:
+ composio_params["description"] = final_description
+ if final_location:
+ composio_params["location"] = final_location
+
+ composio_result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLECALENDAR_CREATE_EVENT",
+ params=composio_params,
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not composio_result.get("success"):
+ raise RuntimeError(
+ composio_result.get(
+ "error", "Unknown Composio Calendar error"
+ )
+ )
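+                        # Composio responses can nest the created event;
+                        # unwrap successive data/response_data layers when
+                        # present.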
+                        created = composio_result.get("data", {})
+                        if isinstance(created, dict):
+                            created = created.get("data", created)
+                        if isinstance(created, dict):
+                            created = created.get("response_data", created)
+                        if not isinstance(created, dict):
+                            created = {}
+ else:
+ service = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ created = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .insert(calendarId="primary", body=event_body)
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(
+ f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
+ )
+
+ kb_message_suffix = ""
+ try:
+ from app.services.google_calendar import GoogleCalendarKBSyncService
+
+ kb_service = GoogleCalendarKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ event_id=created.get("id"),
+ event_summary=final_summary,
+ calendar_id="primary",
+ start_time=final_start_datetime,
+ end_time=final_end_datetime,
+ location=final_location,
+ html_link=created.get("htmlLink"),
+ description=final_description,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "event_id": created.get("id"),
+ "html_link": created.get("htmlLink"),
+ "message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
index 4d9d69b4b..e7e891b08 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_delete_calendar_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the delete_calendar_event tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured delete_calendar_event tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def delete_calendar_event(
event_title_or_id: str,
@@ -54,240 +72,258 @@ def create_delete_calendar_event_tool(
f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
- metadata_service = GoogleCalendarToolMetadataService(db_session)
- context = await metadata_service.get_deletion_context(
- search_space_id, user_id, event_title_or_id
- )
-
- if "error" in context:
- error_msg = context["error"]
- if "not found" in error_msg.lower():
- logger.warning(f"Event not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- logger.error(f"Failed to fetch deletion context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Google Calendar account %s has expired authentication",
- account.get("id"),
+ async with async_session_maker() as db_session:
+ metadata_service = GoogleCalendarToolMetadataService(db_session)
+ context = await metadata_service.get_deletion_context(
+ search_space_id, user_id, event_title_or_id
)
- return {
- "status": "auth_error",
- "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "google_calendar",
- }
- event = context["event"]
- event_id = event["event_id"]
- document_id = event.get("document_id")
- connector_id_from_context = context["account"]["id"]
+ if "error" in context:
+ error_msg = context["error"]
+ if "not found" in error_msg.lower():
+ logger.warning(f"Event not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ logger.error(f"Failed to fetch deletion context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- if not event_id:
- return {
- "status": "error",
- "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
- }
+ account = context.get("account", {})
+ if account.get("auth_expired"):
+ logger.warning(
+ "Google Calendar account %s has expired authentication",
+ account.get("id"),
+ )
+ return {
+ "status": "auth_error",
+ "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "google_calendar",
+ }
- logger.info(
- f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
- )
- result = request_approval(
- action_type="google_calendar_event_deletion",
- tool_name="delete_calendar_event",
- params={
- "event_id": event_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
+ event = context["event"]
+ event_id = event["event_id"]
+ document_id = event.get("document_id")
+ connector_id_from_context = context["account"]["id"]
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
- }
-
- final_event_id = result.params.get("event_id", event_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this event.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _calendar_types = [
- SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
- ]
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_calendar_types),
- )
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Google Calendar connector is invalid or has been disconnected.",
- }
-
- actual_connector_id = connector.id
-
- logger.info(
- f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
- )
-
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
+ if not event_id:
return {
"status": "error",
- "message": "Composio connected account ID not found for this connector.",
+ "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
}
- else:
- config_data = dict(connector.config)
- from app.config import config as app_config
- from app.utils.oauth_security import TokenEncryption
-
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and app_config.SECRET_KEY:
- token_encryption = TokenEncryption(app_config.SECRET_KEY)
- for key in ("token", "refresh_token", "client_secret"):
- if config_data.get(key):
- config_data[key] = token_encryption.decrypt_token(
- config_data[key]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ logger.info(
+ f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
+ )
+ result = request_approval(
+ action_type="google_calendar_event_deletion",
+ tool_name="delete_calendar_event",
+ params={
+ "event_id": event_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- service = await asyncio.get_event_loop().run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
-
- try:
- await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- service.events()
- .delete(calendarId="primary", eventId=final_event_id)
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
- )
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
- )
- )
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
- )
+ if result.rejected:
return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "rejected",
+ "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
}
- raise
- logger.info(f"Calendar event deleted: event_id={final_event_id}")
-
- delete_result: dict[str, Any] = {
- "status": "success",
- "event_id": final_event_id,
- "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
- }
-
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- from app.db import Document
-
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- delete_result["warning"] = (
- f"Event deleted, but failed to remove from knowledge base: {e!s}"
- )
-
- delete_result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- delete_result["message"] = (
- f"{delete_result.get('message', '')} (also removed from knowledge base)"
+ final_event_id = result.params.get("event_id", event_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
)
- return delete_result
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this event.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _calendar_types = [
+ SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
+ ]
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_calendar_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Google Calendar connector is invalid or has been disconnected.",
+ }
+
+ actual_connector_id = connector.id
+
+ logger.info(
+ f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
+ )
+
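+                # Composio-managed connectors only need the connected-account
+                # id here; native connectors rebuild Google OAuth credentials
+                # from the stored (possibly encrypted) config below.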
+ is_composio_calendar = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+ )
+ if is_composio_calendar:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this connector.",
+ }
+ else:
+ config_data = dict(connector.config)
+
+ from app.config import config as app_config
+ from app.utils.oauth_security import TokenEncryption
+
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and app_config.SECRET_KEY:
+ token_encryption = TokenEncryption(app_config.SECRET_KEY)
+ for key in ("token", "refresh_token", "client_secret"):
+ if config_data.get(key):
+ config_data[key] = token_encryption.decrypt_token(
+ config_data[key]
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
+ try:
+ if is_composio_calendar:
+ from app.services.composio_service import ComposioService
+
+ composio_result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLECALENDAR_DELETE_EVENT",
+ params={
+ "calendar_id": "primary",
+ "event_id": final_event_id,
+ },
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not composio_result.get("success"):
+ raise RuntimeError(
+ composio_result.get(
+ "error", "Unknown Composio Calendar error"
+ )
+ )
+ else:
+ service = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .delete(calendarId="primary", eventId=final_event_id)
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(f"Calendar event deleted: event_id={final_event_id}")
+
+ delete_result: dict[str, Any] = {
+ "status": "success",
+ "event_id": final_event_id,
+ "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
+ }
+
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ delete_result["warning"] = (
+ f"Event deleted, but failed to remove from knowledge base: {e!s}"
+ )
+
+ delete_result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ delete_result["message"] = (
+ f"{delete_result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return delete_result
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
index dc6adb822..e5f18f675 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
@@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@@ -16,11 +16,57 @@ _CALENDAR_TYPES = [
]
+def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
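+    """Expand a bare YYYY-MM-DD date into an RFC 3339 UTC day boundary.
+
+    Values that already carry a time component are passed through unchanged.
+    """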
+ if "T" in value:
+ return value
+ time = "23:59:59" if is_end else "00:00:00"
+ return f"{value}T{time}Z"
+
+
+def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
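+    """Project raw Google Calendar events into the compact shape returned by
+    search_calendar_events (attendees capped at the first 10 addresses).
+    """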
+ events = []
+ for ev in events_raw:
+ start = ev.get("start", {})
+ end = ev.get("end", {})
+ attendees_raw = ev.get("attendees", [])
+ events.append(
+ {
+ "event_id": ev.get("id"),
+ "summary": ev.get("summary", "No Title"),
+ "start": start.get("dateTime") or start.get("date", ""),
+ "end": end.get("dateTime") or end.get("date", ""),
+ "location": ev.get("location", ""),
+ "description": ev.get("description", ""),
+ "html_link": ev.get("htmlLink", ""),
+ "attendees": [a.get("email", "") for a in attendees_raw[:10]],
+ "status": ev.get("status", ""),
+ }
+ )
+ return events
+
+
def create_search_calendar_events_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the search_calendar_events tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured search_calendar_events tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def search_calendar_events(
start_date: str,
@@ -38,7 +84,7 @@ def create_search_calendar_events_tool(
Dictionary with status and a list of events including
event_id, summary, start, end, location, attendees.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Calendar tool not properly configured.",
@@ -47,76 +93,85 @@ def create_search_calendar_events_tool(
max_results = min(max_results, 50)
try:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
+ )
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
- }
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
+ }
- creds = _build_credentials(connector)
-
- from app.connectors.google_calendar_connector import GoogleCalendarConnector
-
- cal = GoogleCalendarConnector(
- credentials=creds,
- session=db_session,
- user_id=user_id,
- connector_id=connector.id,
- )
-
- events_raw, error = await cal.get_all_primary_calendar_events(
- start_date=start_date,
- end_date=end_date,
- max_results=max_results,
- )
-
- if error:
if (
- "re-authenticate" in error.lower()
- or "authentication failed" in error.lower()
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
- return {
- "status": "auth_error",
- "message": error,
- "connector_type": "google_calendar",
- }
- if "no events found" in error.lower():
- return {
- "status": "success",
- "events": [],
- "total": 0,
- "message": error,
- }
- return {"status": "error", "message": error}
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this connector.",
+ }
- events = []
- for ev in events_raw:
- start = ev.get("start", {})
- end = ev.get("end", {})
- attendees_raw = ev.get("attendees", [])
- events.append(
- {
- "event_id": ev.get("id"),
- "summary": ev.get("summary", "No Title"),
- "start": start.get("dateTime") or start.get("date", ""),
- "end": end.get("dateTime") or end.get("date", ""),
- "location": ev.get("location", ""),
- "description": ev.get("description", ""),
- "html_link": ev.get("htmlLink", ""),
- "attendees": [a.get("email", "") for a in attendees_raw[:10]],
- "status": ev.get("status", ""),
- }
- )
+ from app.services.composio_service import ComposioService
- return {"status": "success", "events": events, "total": len(events)}
+ events_raw, error = await ComposioService().get_calendar_events(
+ connected_account_id=cca_id,
+ entity_id=f"surfsense_{user_id}",
+ time_min=_to_calendar_boundary(start_date, is_end=False),
+ time_max=_to_calendar_boundary(end_date, is_end=True),
+ max_results=max_results,
+ )
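+                    # Fall through to the shared "no events found" handling
+                    # below so an empty Composio result reads as a successful
+                    # empty search.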
+ if not events_raw and not error:
+ error = "No events found in the specified date range."
+ else:
+ creds = _build_credentials(connector)
+
+ from app.connectors.google_calendar_connector import (
+ GoogleCalendarConnector,
+ )
+
+ cal = GoogleCalendarConnector(
+ credentials=creds,
+ session=db_session,
+ user_id=user_id,
+ connector_id=connector.id,
+ )
+
+ events_raw, error = await cal.get_all_primary_calendar_events(
+ start_date=start_date,
+ end_date=end_date,
+ max_results=max_results,
+ )
+
+ if error:
+ if (
+ "re-authenticate" in error.lower()
+ or "authentication failed" in error.lower()
+ ):
+ return {
+ "status": "auth_error",
+ "message": error,
+ "connector_type": "google_calendar",
+ }
+ if "no events found" in error.lower():
+ return {
+ "status": "success",
+ "events": [],
+ "total": 0,
+ "message": error,
+ }
+ return {"status": "error", "message": error}
+
+ events = _format_calendar_events(events_raw)
+
+ return {"status": "success", "events": events, "total": len(events)}
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
index 259f52bba..b8561fee6 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
@@ -33,6 +34,23 @@ def create_update_calendar_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the update_calendar_event tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the closure is safe for the
+ compiled-agent cache to share across HTTP requests. Capturing a
+ per-request session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Google Calendar connector
+ user_id: User ID for fetching user-specific context
+
+ Returns:
+ Configured update_calendar_event tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def update_calendar_event(
event_title_or_id: str,
@@ -74,272 +92,317 @@ def create_update_calendar_event_tool(
"""
logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
- metadata_service = GoogleCalendarToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, event_title_or_id
- )
-
- if "error" in context:
- error_msg = context["error"]
- if "not found" in error_msg.lower():
- logger.warning(f"Event not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- logger.error(f"Failed to fetch update context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- if context.get("auth_expired"):
- logger.warning("Google Calendar account has expired authentication")
- return {
- "status": "auth_error",
- "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "google_calendar",
- }
-
- event = context["event"]
- event_id = event["event_id"]
- document_id = event.get("document_id")
- connector_id_from_context = context["account"]["id"]
-
- if not event_id:
- return {
- "status": "error",
- "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
- }
-
- logger.info(
- f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
- )
- result = request_approval(
- action_type="google_calendar_event_update",
- tool_name="update_calendar_event",
- params={
- "event_id": event_id,
- "document_id": document_id,
- "connector_id": connector_id_from_context,
- "new_summary": new_summary,
- "new_start_datetime": new_start_datetime,
- "new_end_datetime": new_end_datetime,
- "new_description": new_description,
- "new_location": new_location,
- "new_attendees": new_attendees,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
- }
-
- final_event_id = result.params.get("event_id", event_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_new_summary = result.params.get("new_summary", new_summary)
- final_new_start_datetime = result.params.get(
- "new_start_datetime", new_start_datetime
- )
- final_new_end_datetime = result.params.get(
- "new_end_datetime", new_end_datetime
- )
- final_new_description = result.params.get(
- "new_description", new_description
- )
- final_new_location = result.params.get("new_location", new_location)
- final_new_attendees = result.params.get("new_attendees", new_attendees)
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this event.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _calendar_types = [
- SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
- ]
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_calendar_types),
+ async with async_session_maker() as db_session:
+ metadata_service = GoogleCalendarToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, event_title_or_id
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Google Calendar connector is invalid or has been disconnected.",
- }
- actual_connector_id = connector.id
+ if "error" in context:
+ error_msg = context["error"]
+ if "not found" in error_msg.lower():
+ logger.warning(f"Event not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ logger.error(f"Failed to fetch update context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- logger.info(
- f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
- )
+ if context.get("auth_expired"):
+ logger.warning("Google Calendar account has expired authentication")
+ return {
+ "status": "auth_error",
+ "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "google_calendar",
+ }
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
+ event = context["event"]
+ event_id = event["event_id"]
+ document_id = event.get("document_id")
+ connector_id_from_context = context["account"]["id"]
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
+ if not event_id:
return {
"status": "error",
- "message": "Composio connected account ID not found for this connector.",
+ "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
}
- else:
- config_data = dict(connector.config)
- from app.config import config as app_config
- from app.utils.oauth_security import TokenEncryption
-
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and app_config.SECRET_KEY:
- token_encryption = TokenEncryption(app_config.SECRET_KEY)
- for key in ("token", "refresh_token", "client_secret"):
- if config_data.get(key):
- config_data[key] = token_encryption.decrypt_token(
- config_data[key]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ logger.info(
+ f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
+ )
+ result = request_approval(
+ action_type="google_calendar_event_update",
+ tool_name="update_calendar_event",
+ params={
+ "event_id": event_id,
+ "document_id": document_id,
+ "connector_id": connector_id_from_context,
+ "new_summary": new_summary,
+ "new_start_datetime": new_start_datetime,
+ "new_end_datetime": new_end_datetime,
+ "new_description": new_description,
+ "new_location": new_location,
+ "new_attendees": new_attendees,
+ },
+ context=context,
)
- service = await asyncio.get_event_loop().run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
+ }
- update_body: dict[str, Any] = {}
- if final_new_summary is not None:
- update_body["summary"] = final_new_summary
- if final_new_start_datetime is not None:
- update_body["start"] = _build_time_body(
- final_new_start_datetime, context
+ final_event_id = result.params.get("event_id", event_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
)
- if final_new_end_datetime is not None:
- update_body["end"] = _build_time_body(final_new_end_datetime, context)
- if final_new_description is not None:
- update_body["description"] = final_new_description
- if final_new_location is not None:
- update_body["location"] = final_new_location
- if final_new_attendees is not None:
- update_body["attendees"] = [
- {"email": e.strip()} for e in final_new_attendees if e.strip()
+ final_new_summary = result.params.get("new_summary", new_summary)
+ final_new_start_datetime = result.params.get(
+ "new_start_datetime", new_start_datetime
+ )
+ final_new_end_datetime = result.params.get(
+ "new_end_datetime", new_end_datetime
+ )
+ final_new_description = result.params.get(
+ "new_description", new_description
+ )
+ final_new_location = result.params.get("new_location", new_location)
+ final_new_attendees = result.params.get("new_attendees", new_attendees)
+
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this event.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _calendar_types = [
+ SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
- if not update_body:
- return {
- "status": "error",
- "message": "No changes specified. Please provide at least one field to update.",
- }
-
- try:
- updated = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- service.events()
- .patch(
- calendarId="primary",
- eventId=final_event_id,
- body=update_body,
- )
- .execute()
- ),
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_calendar_types),
+ )
)
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
- )
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
- )
- )
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
- )
+ connector = result.scalars().first()
+ if not connector:
return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "error",
+ "message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
- raise
- logger.info(f"Calendar event updated: event_id={final_event_id}")
+ actual_connector_id = connector.id
- kb_message_suffix = ""
- if document_id is not None:
- try:
- from app.services.google_calendar import GoogleCalendarKBSyncService
+ logger.info(
+ f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
+ )
- kb_service = GoogleCalendarKBSyncService(db_session)
- kb_result = await kb_service.sync_after_update(
- document_id=document_id,
- event_id=final_event_id,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
+ is_composio_calendar = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+ )
+ if is_composio_calendar:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this connector.",
+ }
+ else:
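+ # Work on a copy of the connector config so decrypted secrets are
+ # never written back onto the ORM row.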
+ config_data = dict(connector.config)
+
+ from app.config import config as app_config
+ from app.utils.oauth_security import TokenEncryption
+
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and app_config.SECRET_KEY:
+ token_encryption = TokenEncryption(app_config.SECRET_KEY)
+ for key in ("token", "refresh_token", "client_secret"):
+ if config_data.get(key):
+ config_data[key] = token_encryption.decrypt_token(
+ config_data[key]
+ )
+
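+ # datetime.fromisoformat() rejects a trailing "Z" before Python 3.11;
+ # strip it before parsing the expiry.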
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
)
- if kb_result["status"] == "success":
- kb_message_suffix = (
- " Your knowledge base has also been updated."
- )
- else:
- kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after update failed: {kb_err}")
- kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
- return {
- "status": "success",
- "event_id": final_event_id,
- "html_link": updated.get("htmlLink"),
- "message": f"Successfully updated the calendar event.{kb_message_suffix}",
- }
+ update_body: dict[str, Any] = {}
+ if final_new_summary is not None:
+ update_body["summary"] = final_new_summary
+ if final_new_start_datetime is not None:
+ update_body["start"] = _build_time_body(
+ final_new_start_datetime, context
+ )
+ if final_new_end_datetime is not None:
+ update_body["end"] = _build_time_body(
+ final_new_end_datetime, context
+ )
+ if final_new_description is not None:
+ update_body["description"] = final_new_description
+ if final_new_location is not None:
+ update_body["location"] = final_new_location
+ if final_new_attendees is not None:
+ update_body["attendees"] = [
+ {"email": e.strip()} for e in final_new_attendees if e.strip()
+ ]
+
+ if not update_body:
+ return {
+ "status": "error",
+ "message": "No changes specified. Please provide at least one field to update.",
+ }
+
+ try:
+ if is_composio_calendar:
+ from app.services.composio_service import ComposioService
+
+ composio_params: dict[str, Any] = {
+ "calendar_id": "primary",
+ "event_id": final_event_id,
+ }
+ if final_new_summary is not None:
+ composio_params["summary"] = final_new_summary
+ if final_new_start_datetime is not None:
+ composio_params["start_time"] = final_new_start_datetime
+ if final_new_end_datetime is not None:
+ composio_params["end_time"] = final_new_end_datetime
+ if final_new_description is not None:
+ composio_params["description"] = final_new_description
+ if final_new_location is not None:
+ composio_params["location"] = final_new_location
+ if final_new_attendees is not None:
+ composio_params["attendees"] = [
+ e.strip() for e in final_new_attendees if e.strip()
+ ]
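+ # All-day (date-only) events carry no timezone; only attach one when
+ # the update includes a timed start or end.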
+ if not _is_date_only(
+ final_new_start_datetime or final_new_end_datetime or ""
+ ):
+ composio_params["timezone"] = context.get("timezone", "UTC")
+
+ composio_result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLECALENDAR_PATCH_EVENT",
+ params=composio_params,
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not composio_result.get("success"):
+ raise RuntimeError(
+ composio_result.get(
+ "error", "Unknown Composio Calendar error"
+ )
+ )
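+ # Composio responses may nest the event payload under
+ # data -> data -> response_data; unwrap defensively.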
+ updated = composio_result.get("data", {})
+ if isinstance(updated, dict):
+ updated = updated.get("data", updated)
+ if isinstance(updated, dict):
+ updated = updated.get("response_data", updated)
+ else:
+ service = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ updated = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .patch(
+ calendarId="primary",
+ eventId=final_event_id,
+ body=update_body,
+ )
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(f"Calendar event updated: event_id={final_event_id}")
+
+ kb_message_suffix = ""
+ if document_id is not None:
+ try:
+ from app.services.google_calendar import (
+ GoogleCalendarKBSyncService,
+ )
+
+ kb_service = GoogleCalendarKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_update(
+ document_id=document_id,
+ event_id=final_event_id,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after update failed: {kb_err}")
+ kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "event_id": final_event_id,
+ "html_link": updated.get("htmlLink"),
+ "message": f"Successfully updated the calendar event.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py
index f36db8f3f..66199ca67 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py
@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
+from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__)
@@ -23,6 +24,25 @@ def create_create_google_drive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_google_drive_file tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the closure is safe for the
+ compiled-agent cache to share across HTTP requests. Capturing a
+ per-request session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Google Drive connector
+ user_id: User ID for fetching user-specific context
+
+ Returns:
+ Configured create_google_drive_file tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_google_drive_file(
name: str,
@@ -65,7 +85,7 @@ def create_create_google_drive_file_tool(
f"create_google_drive_file called: name='{name}', type='{file_type}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Drive tool not properly configured. Please contact support.",
@@ -78,195 +98,232 @@ def create_create_google_drive_file_tool(
}
try:
- metadata_service = GoogleDriveToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
+ async with async_session_maker() as db_session:
+ metadata_service = GoogleDriveToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
+ )
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {"status": "error", "message": context["error"]}
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- logger.warning("All Google Drive accounts have expired authentication")
- return {
- "status": "auth_error",
- "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "google_drive",
- }
-
- logger.info(
- f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
- )
- result = request_approval(
- action_type="google_drive_file_creation",
- tool_name="create_google_drive_file",
- params={
- "name": name,
- "file_type": file_type,
- "content": content,
- "connector_id": None,
- "parent_folder_id": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
- }
-
- final_name = result.params.get("name", name)
- final_file_type = result.params.get("file_type", file_type)
- final_content = result.params.get("content", content)
- final_connector_id = result.params.get("connector_id")
- final_parent_folder_id = result.params.get("parent_folder_id")
-
- if not final_name or not final_name.strip():
- return {"status": "error", "message": "File name cannot be empty."}
-
- mime_type = _MIME_MAP.get(final_file_type)
- if not mime_type:
- return {
- "status": "error",
- "message": f"Unsupported file type '{final_file_type}'.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _drive_types = [
- SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
- ]
-
- if final_connector_id is not None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_drive_types),
+ if "error" in context:
+ logger.error(
+ f"Failed to fetch creation context: {context['error']}"
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Google Drive connector is invalid or has been disconnected.",
- }
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_drive_types),
- )
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
- }
- actual_connector_id = connector.id
+ return {"status": "error", "message": context["error"]}
- logger.info(
- f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
- )
-
- pre_built_creds = None
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- pre_built_creds = build_composio_credentials(cca_id)
-
- client = GoogleDriveClient(
- session=db_session,
- connector_id=actual_connector_id,
- credentials=pre_built_creds,
- )
- try:
- created = await client.create_file(
- name=final_name,
- mime_type=mime_type,
- parent_folder_id=final_parent_folder_id,
- content=final_content,
- )
- except HttpError as http_err:
- if http_err.resp.status == 403:
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
+ "All Google Drive accounts have expired authentication"
)
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
- )
- )
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
- )
return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "auth_error",
+ "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "google_drive",
}
- raise
- logger.info(
- f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
- )
-
- kb_message_suffix = ""
- try:
- from app.services.google_drive import GoogleDriveKBSyncService
-
- kb_service = GoogleDriveKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- file_id=created.get("id"),
- file_name=created.get("name", final_name),
- mime_type=mime_type,
- web_view_link=created.get("webViewLink"),
- content=final_content,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
+ logger.info(
+ f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
+ )
+ result = request_approval(
+ action_type="google_drive_file_creation",
+ tool_name="create_google_drive_file",
+ params={
+ "name": name,
+ "file_type": file_type,
+ "content": content,
+ "connector_id": None,
+ "parent_folder_id": None,
+ },
+ context=context,
)
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
- else:
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
- return {
- "status": "success",
- "file_id": created.get("id"),
- "name": created.get("name"),
- "web_view_link": created.get("webViewLink"),
- "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
+ }
+
+ final_name = result.params.get("name", name)
+ final_file_type = result.params.get("file_type", file_type)
+ final_content = result.params.get("content", content)
+ final_connector_id = result.params.get("connector_id")
+ final_parent_folder_id = result.params.get("parent_folder_id")
+
+ if not final_name or not final_name.strip():
+ return {"status": "error", "message": "File name cannot be empty."}
+
+ mime_type = _MIME_MAP.get(final_file_type)
+ if not mime_type:
+ return {
+ "status": "error",
+ "message": f"Unsupported file type '{final_file_type}'.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _drive_types = [
+ SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
+ ]
+
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_drive_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Google Drive connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = connector.id
+ else:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_drive_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
+ }
+ actual_connector_id = connector.id
+
+ logger.info(
+ f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
+ )
+
+ is_composio_drive = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
+ )
+ if is_composio_drive:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Drive connector.",
+ }
+ client = GoogleDriveClient(
+ session=db_session,
+ connector_id=actual_connector_id,
+ )
+ try:
+ if is_composio_drive:
+ from app.services.composio_service import ComposioService
+
+ params: dict[str, Any] = {
+ "name": final_name,
+ "mimeType": mime_type,
+ "fields": "id,name,webViewLink,mimeType",
+ }
+ if final_parent_folder_id:
+ params["parents"] = [final_parent_folder_id]
+ if final_content:
+ params["description"] = final_content[:4096]
+
+ result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLEDRIVE_CREATE_FILE",
+ params=params,
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not result.get("success"):
+ raise RuntimeError(
+ result.get("error", "Unknown Composio Drive error")
+ )
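+ # Unwrap possible nested envelopes (data -> data -> response_data)
+ # and fall back to an empty dict so the .get() calls below are safe.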
+ created = result.get("data", {})
+ if isinstance(created, dict):
+ created = created.get("data", created)
+ if isinstance(created, dict):
+ created = created.get("response_data", created)
+ if not isinstance(created, dict):
+ created = {}
+ else:
+ created = await client.create_file(
+ name=final_name,
+ mime_type=mime_type,
+ parent_folder_id=final_parent_folder_id,
+ content=final_content,
+ )
+ except HttpError as http_err:
+ if http_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(
+ f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
+ )
+
+ kb_message_suffix = ""
+ try:
+ from app.services.google_drive import GoogleDriveKBSyncService
+
+ kb_service = GoogleDriveKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ file_id=created.get("id"),
+ file_name=created.get("name", final_name),
+ mime_type=mime_type,
+ web_view_link=created.get("webViewLink"),
+ content=final_content,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "file_id": created.get("id"),
+ "name": created.get("name"),
+ "web_view_link": created.get("webViewLink"),
+ "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py
index 832afff0d..b3c9240d8 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py
@@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
+from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__)
@@ -17,6 +18,25 @@ def create_delete_google_drive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the delete_google_drive_file tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the closure is safe for the
+ compiled-agent cache to share across HTTP requests. Capturing a
+ per-request session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Google Drive connector
+ user_id: User ID for fetching user-specific context
+
+ Returns:
+ Configured delete_google_drive_file tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def delete_google_drive_file(
file_name: str,
@@ -55,197 +75,214 @@ def create_delete_google_drive_file_tool(
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Drive tool not properly configured. Please contact support.",
}
try:
- metadata_service = GoogleDriveToolMetadataService(db_session)
- context = await metadata_service.get_trash_context(
- search_space_id, user_id, file_name
- )
-
- if "error" in context:
- error_msg = context["error"]
- if "not found" in error_msg.lower():
- logger.warning(f"File not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- logger.error(f"Failed to fetch trash context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Google Drive account %s has expired authentication",
- account.get("id"),
+ async with async_session_maker() as db_session:
+ metadata_service = GoogleDriveToolMetadataService(db_session)
+ context = await metadata_service.get_trash_context(
+ search_space_id, user_id, file_name
)
- return {
- "status": "auth_error",
- "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "google_drive",
- }
- file = context["file"]
- file_id = file["file_id"]
- document_id = file.get("document_id")
- connector_id_from_context = context["account"]["id"]
+ if "error" in context:
+ error_msg = context["error"]
+ if "not found" in error_msg.lower():
+ logger.warning(f"File not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ logger.error(f"Failed to fetch trash context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- if not file_id:
- return {
- "status": "error",
- "message": "File ID is missing from the indexed document. Please re-index the file and try again.",
- }
-
- logger.info(
- f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
- )
- result = request_approval(
- action_type="google_drive_file_trash",
- tool_name="delete_google_drive_file",
- params={
- "file_id": file_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
- }
-
- final_file_id = result.params.get("file_id", file_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this file.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _drive_types = [
- SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
- ]
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_drive_types),
- )
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Google Drive connector is invalid or has been disconnected.",
- }
-
- logger.info(
- f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
- )
-
- pre_built_creds = None
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- pre_built_creds = build_composio_credentials(cca_id)
-
- client = GoogleDriveClient(
- session=db_session,
- connector_id=connector.id,
- credentials=pre_built_creds,
- )
- try:
- await client.trash_file(file_id=final_file_id)
- except HttpError as http_err:
- if http_err.resp.status == 403:
+ account = context.get("account", {})
+ if account.get("auth_expired"):
logger.warning(
- f"Insufficient permissions for connector {connector.id}: {http_err}"
+ "Google Drive account %s has expired authentication",
+ account.get("id"),
)
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- if not connector.config.get("auth_expired"):
- connector.config = {
- **connector.config,
- "auth_expired": True,
- }
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- connector.id,
- exc_info=True,
- )
return {
- "status": "insufficient_permissions",
- "connector_id": connector.id,
- "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "auth_error",
+ "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "google_drive",
}
- raise
- logger.info(
- f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
- )
+ file = context["file"]
+ file_id = file["file_id"]
+ document_id = file.get("document_id")
+ connector_id_from_context = context["account"]["id"]
- trash_result: dict[str, Any] = {
- "status": "success",
- "file_id": final_file_id,
- "message": f"Successfully moved '{file['name']}' to trash.",
- }
+ if not file_id:
+ return {
+ "status": "error",
+ "message": "File ID is missing from the indexed document. Please re-index the file and try again.",
+ }
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- from app.db import Document
-
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- trash_result["warning"] = (
- f"File moved to trash, but failed to remove from knowledge base: {e!s}"
- )
-
- trash_result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- trash_result["message"] = (
- f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ logger.info(
+ f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
+ )
+ result = request_approval(
+ action_type="google_drive_file_trash",
+ tool_name="delete_google_drive_file",
+ params={
+ "file_id": file_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- return trash_result
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
+ }
+
+ final_file_id = result.params.get("file_id", file_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
+
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this file.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _drive_types = [
+ SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
+ ]
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_drive_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Google Drive connector is invalid or has been disconnected.",
+ }
+
+ logger.info(
+ f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
+ )
+
+ is_composio_drive = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
+ )
+ if is_composio_drive:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Drive connector.",
+ }
+
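+ # The Drive client is only exercised on the native (non-Composio)
+ # path below; the Composio path calls the service directly.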
+ client = GoogleDriveClient(
+ session=db_session,
+ connector_id=connector.id,
+ )
+ try:
+ if is_composio_drive:
+ from app.services.composio_service import ComposioService
+
+ result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLEDRIVE_TRASH_FILE",
+ params={"file_id": final_file_id},
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not result.get("success"):
+ raise RuntimeError(
+ result.get("error", "Unknown Composio Drive error")
+ )
+ else:
+ await client.trash_file(file_id=final_file_id)
+ except HttpError as http_err:
+ if http_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {connector.id}: {http_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ if not connector.config.get("auth_expired"):
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ connector.id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": connector.id,
+ "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(
+ f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
+ )
+
+ trash_result: dict[str, Any] = {
+ "status": "success",
+ "file_id": final_file_id,
+ "message": f"Successfully moved '{file['name']}' to trash.",
+ }
+
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ trash_result["warning"] = (
+ f"File moved to trash, but failed to remove from knowledge base: {e!s}"
+ )
+
+ trash_result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ trash_result["message"] = (
+ f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py
index 92248c2c9..5b64929de 100644
--- a/surfsense_backend/app/agents/new_chat/tools/hitl.py
+++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py
@@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
{
"create_gmail_draft",
"update_gmail_draft",
+ "create_calendar_event",
"create_notion_page",
"create_confluence_page",
"create_google_drive_file",
diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py
index 8b40dde65..0b04f1642 100644
--- a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py
@@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
+from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,28 @@ def create_create_jira_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """Factory function to create the create_jira_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits. Per-call sessions also
+ keep long-running Jira API calls from blocking the request's
+ outer transaction.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Jira connector
+ user_id: User ID for fetching user-specific context
+ connector_id: Optional specific connector ID (if known)
+
+ Returns:
+ Configured create_jira_issue tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_jira_issue(
project_key: str,
@@ -49,158 +72,167 @@ def create_create_jira_issue_tool(
f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
- metadata_service = JiraToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
-
- if "error" in context:
- return {"status": "error", "message": context["error"]}
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- return {
- "status": "auth_error",
- "message": "All connected Jira accounts need re-authentication.",
- "connector_type": "jira",
- }
-
- result = request_approval(
- action_type="jira_issue_creation",
- tool_name="create_jira_issue",
- params={
- "project_key": project_key,
- "summary": summary,
- "issue_type": issue_type,
- "description": description,
- "priority": priority,
- "connector_id": connector_id,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_project_key = result.params.get("project_key", project_key)
- final_summary = result.params.get("summary", summary)
- final_issue_type = result.params.get("issue_type", issue_type)
- final_description = result.params.get("description", description)
- final_priority = result.params.get("priority", priority)
- final_connector_id = result.params.get("connector_id", connector_id)
-
- if not final_summary or not final_summary.strip():
- return {"status": "error", "message": "Issue summary cannot be empty."}
- if not final_project_key:
- return {"status": "error", "message": "A project must be selected."}
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- actual_connector_id = final_connector_id
- if actual_connector_id is None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.JIRA_CONNECTOR,
- )
+ async with async_session_maker() as db_session:
+ metadata_service = JiraToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
)
- connector = result.scalars().first()
- if not connector:
- return {"status": "error", "message": "No Jira connector found."}
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == actual_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.JIRA_CONNECTOR,
- )
+
+ if "error" in context:
+ return {"status": "error", "message": context["error"]}
+
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ return {
+ "status": "auth_error",
+ "message": "All connected Jira accounts need re-authentication.",
+ "connector_type": "jira",
+ }
+
+ result = request_approval(
+ action_type="jira_issue_creation",
+ tool_name="create_jira_issue",
+ params={
+ "project_key": project_key,
+ "summary": summary,
+ "issue_type": issue_type,
+ "description": description,
+ "priority": priority,
+ "connector_id": connector_id,
+ },
+ context=context,
)
- connector = result.scalars().first()
- if not connector:
+
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_project_key = result.params.get("project_key", project_key)
+ final_summary = result.params.get("summary", summary)
+ final_issue_type = result.params.get("issue_type", issue_type)
+ final_description = result.params.get("description", description)
+ final_priority = result.params.get("priority", priority)
+ final_connector_id = result.params.get("connector_id", connector_id)
+
+ if not final_summary or not final_summary.strip():
return {
"status": "error",
- "message": "Selected Jira connector is invalid.",
+ "message": "Issue summary cannot be empty.",
}
+ if not final_project_key:
+ return {"status": "error", "message": "A project must be selected."}
- try:
- jira_history = JiraHistoryConnector(
- session=db_session, connector_id=actual_connector_id
- )
- jira_client = await jira_history._get_jira_client()
- api_result = await asyncio.to_thread(
- jira_client.create_issue,
- project_key=final_project_key,
- summary=final_summary,
- issue_type=final_issue_type,
- description=final_description,
- priority=final_priority,
- )
- except Exception as api_err:
- if "status code 403" in str(api_err).lower():
- try:
- _conn = connector
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- pass
- return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
- }
- raise
+ from sqlalchemy.future import select
- issue_key = api_result.get("key", "")
- issue_url = (
- f"{jira_history._base_url}/browse/{issue_key}"
- if jira_history._base_url and issue_key
- else ""
- )
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
- kb_message_suffix = ""
- try:
- from app.services.jira import JiraKBSyncService
-
- kb_service = JiraKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- issue_id=issue_key,
- issue_identifier=issue_key,
- issue_title=final_summary,
- description=final_description,
- state="To Do",
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
+ actual_connector_id = final_connector_id
+ if actual_connector_id is None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.JIRA_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Jira connector found.",
+ }
+ actual_connector_id = connector.id
else:
- kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == actual_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.JIRA_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Jira connector is invalid.",
+ }
- return {
- "status": "success",
- "issue_key": issue_key,
- "issue_url": issue_url,
- "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
- }
+ try:
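+ # The Jira client is synchronous, so the blocking create call runs in a
+ # worker thread via asyncio.to_thread to keep the event loop responsive.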
+ jira_history = JiraHistoryConnector(
+ session=db_session, connector_id=actual_connector_id
+ )
+ jira_client = await jira_history._get_jira_client()
+ api_result = await asyncio.to_thread(
+ jira_client.create_issue,
+ project_key=final_project_key,
+ summary=final_summary,
+ issue_type=final_issue_type,
+ description=final_description,
+ priority=final_priority,
+ )
+ except Exception as api_err:
+ if "status code 403" in str(api_err).lower():
+ try:
+ _conn = connector
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ issue_key = api_result.get("key", "")
+ issue_url = (
+ f"{jira_history._base_url}/browse/{issue_key}"
+ if jira_history._base_url and issue_key
+ else ""
+ )
+
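+ # KB sync is best-effort: any failure degrades to the "next scheduled
+ # sync" message rather than failing the successful issue creation.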
+ kb_message_suffix = ""
+ try:
+ from app.services.jira import JiraKBSyncService
+
+ kb_service = JiraKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ issue_id=issue_key,
+ issue_identifier=issue_key,
+ issue_title=final_summary,
+ description=final_description,
+ state="To Do",
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "issue_key": issue_key,
+ "issue_url": issue_url,
+ "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py
index 6466c80ea..c41aedad9 100644
--- a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py
@@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
+from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,26 @@ def create_delete_jira_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """Factory function to create the delete_jira_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Jira connector
+ user_id: User ID for fetching user-specific context
+ connector_id: Optional specific connector ID (if known)
+
+ Returns:
+ Configured delete_jira_issue tool
+ """
+ del db_session # per-call session — see docstring
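+ # A minimal sketch of the pattern (assuming ``async_session_maker`` is the
+ # app-wide ``async_sessionmaker``): each call opens and closes its own
+ # session, so the cached closure never holds request-scoped state:
+ #
+ #     async with async_session_maker() as db_session:
+ #         ...  # all DB reads/writes for this invocation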
+
@tool
async def delete_jira_issue(
issue_title_or_key: str,
@@ -44,130 +65,136 @@ def create_delete_jira_issue_tool(
f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
- metadata_service = JiraToolMetadataService(db_session)
- context = await metadata_service.get_deletion_context(
- search_space_id, user_id, issue_title_or_key
- )
-
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
- return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "jira",
- }
- if "not found" in error_msg.lower():
- return {"status": "not_found", "message": error_msg}
- return {"status": "error", "message": error_msg}
-
- issue_data = context["issue"]
- issue_key = issue_data["issue_id"]
- document_id = issue_data["document_id"]
- connector_id_from_context = context.get("account", {}).get("id")
-
- result = request_approval(
- action_type="jira_issue_deletion",
- tool_name="delete_jira_issue",
- params={
- "issue_key": issue_key,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_issue_key = result.params.get("issue_key", issue_key)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this issue.",
- }
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.JIRA_CONNECTOR,
+ async with async_session_maker() as db_session:
+ metadata_service = JiraToolMetadataService(db_session)
+ context = await metadata_service.get_deletion_context(
+ search_space_id, user_id, issue_title_or_key
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Jira connector is invalid.",
- }
- try:
- jira_history = JiraHistoryConnector(
- session=db_session, connector_id=final_connector_id
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "jira",
+ }
+ if "not found" in error_msg.lower():
+ return {"status": "not_found", "message": error_msg}
+ return {"status": "error", "message": error_msg}
+
+ issue_data = context["issue"]
+ issue_key = issue_data["issue_id"]
+ document_id = issue_data["document_id"]
+ connector_id_from_context = context.get("account", {}).get("id")
+
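+ # Human-in-the-loop gate: surface the pending deletion for the user to
+ # approve, edit, or reject before anything is sent to Jira.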
+ result = request_approval(
+ action_type="jira_issue_deletion",
+ tool_name="delete_jira_issue",
+ params={
+ "issue_key": issue_key,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- jira_client = await jira_history._get_jira_client()
- await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
- except Exception as api_err:
- if "status code 403" in str(api_err).lower():
- try:
- connector.config = {**connector.config, "auth_expired": True}
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- pass
+
+ if result.rejected:
return {
- "status": "insufficient_permissions",
- "connector_id": final_connector_id,
- "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
}
- raise
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- from app.db import Document
+ final_issue_key = result.params.get("issue_key", issue_key)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this issue.",
+ }
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.JIRA_CONNECTOR,
)
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Jira connector is invalid.",
+ }
- message = f"Jira issue {final_issue_key} deleted successfully."
- if deleted_from_kb:
- message += " Also removed from the knowledge base."
+ try:
+ jira_history = JiraHistoryConnector(
+ session=db_session, connector_id=final_connector_id
+ )
+ jira_client = await jira_history._get_jira_client()
+ await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
+ except Exception as api_err:
+ if "status code 403" in str(api_err).lower():
+ try:
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": final_connector_id,
+ "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
- return {
- "status": "success",
- "issue_key": final_issue_key,
- "deleted_from_kb": deleted_from_kb,
- "message": message,
- }
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+
+ message = f"Jira issue {final_issue_key} deleted successfully."
+ if deleted_from_kb:
+ message += " Also removed from the knowledge base."
+
+ return {
+ "status": "success",
+ "issue_key": final_issue_key,
+ "deleted_from_kb": deleted_from_kb,
+ "message": message,
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py
index f6e586a2e..0fd7b28b3 100644
--- a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py
@@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
+from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,26 @@ def create_update_jira_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """Factory function to create the update_jira_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Jira connector
+ user_id: User ID for fetching user-specific context
+ connector_id: Optional specific connector ID (if known)
+
+ Returns:
+ Configured update_jira_issue tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def update_jira_issue(
issue_title_or_key: str,
@@ -48,169 +69,177 @@ def create_update_jira_issue_tool(
f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
- metadata_service = JiraToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, issue_title_or_key
- )
-
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
- return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "jira",
- }
- if "not found" in error_msg.lower():
- return {"status": "not_found", "message": error_msg}
- return {"status": "error", "message": error_msg}
-
- issue_data = context["issue"]
- issue_key = issue_data["issue_id"]
- document_id = issue_data.get("document_id")
- connector_id_from_context = context.get("account", {}).get("id")
-
- result = request_approval(
- action_type="jira_issue_update",
- tool_name="update_jira_issue",
- params={
- "issue_key": issue_key,
- "document_id": document_id,
- "new_summary": new_summary,
- "new_description": new_description,
- "new_priority": new_priority,
- "connector_id": connector_id_from_context,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_issue_key = result.params.get("issue_key", issue_key)
- final_summary = result.params.get("new_summary", new_summary)
- final_description = result.params.get("new_description", new_description)
- final_priority = result.params.get("new_priority", new_priority)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_document_id = result.params.get("document_id", document_id)
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this issue.",
- }
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.JIRA_CONNECTOR,
+ async with async_session_maker() as db_session:
+ metadata_service = JiraToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, issue_title_or_key
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Jira connector is invalid.",
- }
- fields: dict[str, Any] = {}
- if final_summary:
- fields["summary"] = final_summary
- if final_description is not None:
- fields["description"] = {
- "type": "doc",
- "version": 1,
- "content": [
- {
- "type": "paragraph",
- "content": [{"type": "text", "text": final_description}],
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "jira",
}
- ],
- }
- if final_priority:
- fields["priority"] = {"name": final_priority}
+ if "not found" in error_msg.lower():
+ return {"status": "not_found", "message": error_msg}
+ return {"status": "error", "message": error_msg}
- if not fields:
- return {"status": "error", "message": "No changes specified."}
+ issue_data = context["issue"]
+ issue_key = issue_data["issue_id"]
+ document_id = issue_data.get("document_id")
+ connector_id_from_context = context.get("account", {}).get("id")
- try:
- jira_history = JiraHistoryConnector(
- session=db_session, connector_id=final_connector_id
+ result = request_approval(
+ action_type="jira_issue_update",
+ tool_name="update_jira_issue",
+ params={
+ "issue_key": issue_key,
+ "document_id": document_id,
+ "new_summary": new_summary,
+ "new_description": new_description,
+ "new_priority": new_priority,
+ "connector_id": connector_id_from_context,
+ },
+ context=context,
)
- jira_client = await jira_history._get_jira_client()
- await asyncio.to_thread(
- jira_client.update_issue, final_issue_key, fields
- )
- except Exception as api_err:
- if "status code 403" in str(api_err).lower():
- try:
- connector.config = {**connector.config, "auth_expired": True}
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- pass
+
+ if result.rejected:
return {
- "status": "insufficient_permissions",
- "connector_id": final_connector_id,
- "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
}
- raise
- issue_url = (
- f"{jira_history._base_url}/browse/{final_issue_key}"
- if jira_history._base_url and final_issue_key
- else ""
- )
+ final_issue_key = result.params.get("issue_key", issue_key)
+ final_summary = result.params.get("new_summary", new_summary)
+ final_description = result.params.get(
+ "new_description", new_description
+ )
+ final_priority = result.params.get("new_priority", new_priority)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_document_id = result.params.get("document_id", document_id)
- kb_message_suffix = ""
- if final_document_id:
- try:
- from app.services.jira import JiraKBSyncService
+ from sqlalchemy.future import select
- kb_service = JiraKBSyncService(db_session)
- kb_result = await kb_service.sync_after_update(
- document_id=final_document_id,
- issue_id=final_issue_key,
- user_id=user_id,
- search_space_id=search_space_id,
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this issue.",
+ }
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.JIRA_CONNECTOR,
)
- if kb_result["status"] == "success":
- kb_message_suffix = (
- " Your knowledge base has also been updated."
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Jira connector is invalid.",
+ }
+
+ fields: dict[str, Any] = {}
+ if final_summary:
+ fields["summary"] = final_summary
+ if final_description is not None:
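+ # Jira Cloud expects rich-text fields as an Atlassian Document Format
+ # (ADF) tree; wrap the plain-text description in a minimal ADF doc.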
+ fields["description"] = {
+ "type": "doc",
+ "version": 1,
+ "content": [
+ {
+ "type": "paragraph",
+ "content": [
+ {"type": "text", "text": final_description}
+ ],
+ }
+ ],
+ }
+ if final_priority:
+ fields["priority"] = {"name": final_priority}
+
+ if not fields:
+ return {"status": "error", "message": "No changes specified."}
+
+ try:
+ jira_history = JiraHistoryConnector(
+ session=db_session, connector_id=final_connector_id
+ )
+ jira_client = await jira_history._get_jira_client()
+ await asyncio.to_thread(
+ jira_client.update_issue, final_issue_key, fields
+ )
+ except Exception as api_err:
+ if "status code 403" in str(api_err).lower():
+ try:
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": final_connector_id,
+ "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ issue_url = (
+ f"{jira_history._base_url}/browse/{final_issue_key}"
+ if jira_history._base_url and final_issue_key
+ else ""
+ )
+
+ kb_message_suffix = ""
+ if final_document_id:
+ try:
+ from app.services.jira import JiraKBSyncService
+
+ kb_service = JiraKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_update(
+ document_id=final_document_id,
+ issue_id=final_issue_key,
+ user_id=user_id,
+ search_space_id=search_space_id,
)
- else:
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = (
+ " The knowledge base will be updated in the next sync."
+ )
+ except Exception as kb_err:
+ logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
- except Exception as kb_err:
- logger.warning(f"KB sync after update failed: {kb_err}")
- kb_message_suffix = (
- " The knowledge base will be updated in the next sync."
- )
- return {
- "status": "success",
- "issue_key": final_issue_key,
- "issue_url": issue_url,
- "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
- }
+ return {
+ "status": "success",
+ "issue_key": final_issue_key,
+ "issue_url": issue_url,
+ "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py
index ff254e133..f897bee7a 100644
--- a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
+from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__)
@@ -17,11 +18,17 @@ def create_create_linear_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
- """
- Factory function to create the create_linear_issue tool.
+ """Factory function to create the create_linear_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits.
Args:
- db_session: Database session for accessing the Linear connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@@ -29,6 +36,7 @@ def create_create_linear_issue_tool(
Returns:
Configured create_linear_issue tool
"""
+ del db_session # per-call session — see docstring
@tool
async def create_linear_issue(
@@ -65,7 +73,7 @@ def create_create_linear_issue_tool(
"""
logger.info(f"create_linear_issue called: title='{title}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Linear tool not properly configured - missing required parameters"
)
@@ -75,160 +83,170 @@ def create_create_linear_issue_tool(
}
try:
- metadata_service = LinearToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
-
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {"status": "error", "message": context["error"]}
-
- workspaces = context.get("workspaces", [])
- if workspaces and all(w.get("auth_expired") for w in workspaces):
- logger.warning("All Linear accounts have expired authentication")
- return {
- "status": "auth_error",
- "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "linear",
- }
-
- logger.info(f"Requesting approval for creating Linear issue: '{title}'")
- result = request_approval(
- action_type="linear_issue_creation",
- tool_name="create_linear_issue",
- params={
- "title": title,
- "description": description,
- "team_id": None,
- "state_id": None,
- "assignee_id": None,
- "priority": None,
- "label_ids": [],
- "connector_id": connector_id,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Linear issue creation rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_title = result.params.get("title", title)
- final_description = result.params.get("description", description)
- final_team_id = result.params.get("team_id")
- final_state_id = result.params.get("state_id")
- final_assignee_id = result.params.get("assignee_id")
- final_priority = result.params.get("priority")
- final_label_ids = result.params.get("label_ids") or []
- final_connector_id = result.params.get("connector_id", connector_id)
-
- if not final_title or not final_title.strip():
- logger.error("Title is empty or contains only whitespace")
- return {"status": "error", "message": "Issue title cannot be empty."}
- if not final_team_id:
- return {
- "status": "error",
- "message": "A team must be selected to create an issue.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- actual_connector_id = final_connector_id
- if actual_connector_id is None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.LINEAR_CONNECTOR,
- )
+ async with async_session_maker() as db_session:
+ metadata_service = LinearToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
)
- connector = result.scalars().first()
- if not connector:
+
+ if "error" in context:
+ logger.error(
+ f"Failed to fetch creation context: {context['error']}"
+ )
+ return {"status": "error", "message": context["error"]}
+
+ workspaces = context.get("workspaces", [])
+ if workspaces and all(w.get("auth_expired") for w in workspaces):
+ logger.warning("All Linear accounts have expired authentication")
+ return {
+ "status": "auth_error",
+ "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "linear",
+ }
+
+ logger.info(f"Requesting approval for creating Linear issue: '{title}'")
+ result = request_approval(
+ action_type="linear_issue_creation",
+ tool_name="create_linear_issue",
+ params={
+ "title": title,
+ "description": description,
+ "team_id": None,
+ "state_id": None,
+ "assignee_id": None,
+ "priority": None,
+ "label_ids": [],
+ "connector_id": connector_id,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ logger.info("Linear issue creation rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_title = result.params.get("title", title)
+ final_description = result.params.get("description", description)
+ final_team_id = result.params.get("team_id")
+ final_state_id = result.params.get("state_id")
+ final_assignee_id = result.params.get("assignee_id")
+ final_priority = result.params.get("priority")
+ final_label_ids = result.params.get("label_ids") or []
+ final_connector_id = result.params.get("connector_id", connector_id)
+
+ if not final_title or not final_title.strip():
+ logger.error("Title is empty or contains only whitespace")
return {
"status": "error",
- "message": "No Linear connector found. Please connect Linear in your workspace settings.",
+ "message": "Issue title cannot be empty.",
}
- actual_connector_id = connector.id
- logger.info(f"Found Linear connector: id={actual_connector_id}")
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == actual_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.LINEAR_CONNECTOR,
- )
- )
- connector = result.scalars().first()
- if not connector:
+ if not final_team_id:
return {
"status": "error",
- "message": "Selected Linear connector is invalid or has been disconnected.",
+ "message": "A team must be selected to create an issue.",
}
- logger.info(f"Validated Linear connector: id={actual_connector_id}")
- logger.info(
- f"Creating Linear issue with final params: title='{final_title}'"
- )
- linear_client = LinearConnector(
- session=db_session, connector_id=actual_connector_id
- )
- result = await linear_client.create_issue(
- team_id=final_team_id,
- title=final_title,
- description=final_description,
- state_id=final_state_id,
- assignee_id=final_assignee_id,
- priority=final_priority,
- label_ids=final_label_ids if final_label_ids else None,
- )
+ from sqlalchemy.future import select
- if result.get("status") == "error":
- logger.error(f"Failed to create Linear issue: {result.get('message')}")
- return {"status": "error", "message": result.get("message")}
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
- logger.info(
- f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
- )
-
- kb_message_suffix = ""
- try:
- from app.services.linear import LinearKBSyncService
-
- kb_service = LinearKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- issue_id=result.get("id"),
- issue_identifier=result.get("identifier", ""),
- issue_title=result.get("title", final_title),
- issue_url=result.get("url"),
- description=final_description,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
+ actual_connector_id = final_connector_id
+ if actual_connector_id is None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.LINEAR_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Linear connector found. Please connect Linear in your workspace settings.",
+ }
+ actual_connector_id = connector.id
+ logger.info(f"Found Linear connector: id={actual_connector_id}")
else:
- kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == actual_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.LINEAR_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Linear connector is invalid or has been disconnected.",
+ }
+ logger.info(f"Validated Linear connector: id={actual_connector_id}")
- return {
- "status": "success",
- "issue_id": result.get("id"),
- "identifier": result.get("identifier"),
- "url": result.get("url"),
- "message": (result.get("message", "") + kb_message_suffix),
- }
+ logger.info(
+ f"Creating Linear issue with final params: title='{final_title}'"
+ )
+ linear_client = LinearConnector(
+ session=db_session, connector_id=actual_connector_id
+ )
+ result = await linear_client.create_issue(
+ team_id=final_team_id,
+ title=final_title,
+ description=final_description,
+ state_id=final_state_id,
+ assignee_id=final_assignee_id,
+ priority=final_priority,
+ label_ids=final_label_ids if final_label_ids else None,
+ )
+
+ if result.get("status") == "error":
+ logger.error(
+ f"Failed to create Linear issue: {result.get('message')}"
+ )
+ return {"status": "error", "message": result.get("message")}
+
+ logger.info(
+ f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
+ )
+
+ kb_message_suffix = ""
+ try:
+ from app.services.linear import LinearKBSyncService
+
+ kb_service = LinearKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ issue_id=result.get("id"),
+ issue_identifier=result.get("identifier", ""),
+ issue_title=result.get("title", final_title),
+ issue_url=result.get("url"),
+ description=final_description,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "issue_id": result.get("id"),
+ "identifier": result.get("identifier"),
+ "url": result.get("url"),
+ "message": (result.get("message", "") + kb_message_suffix),
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py
index 29ef0cdf2..c5039a8eb 100644
--- a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
+from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__)
@@ -17,11 +18,17 @@ def create_delete_linear_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
- """
- Factory function to create the delete_linear_issue tool.
+ """Factory function to create the delete_linear_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits.
Args:
- db_session: Database session for accessing the Linear connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector
user_id: User ID for finding the correct Linear connector
connector_id: Optional specific connector ID (if known)
@@ -29,6 +36,7 @@ def create_delete_linear_issue_tool(
Returns:
Configured delete_linear_issue tool
"""
+ del db_session # per-call session — see docstring
@tool
async def delete_linear_issue(
@@ -73,7 +81,7 @@ def create_delete_linear_issue_tool(
f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Linear tool not properly configured - missing required parameters"
)
@@ -83,149 +91,152 @@ def create_delete_linear_issue_tool(
}
try:
- metadata_service = LinearToolMetadataService(db_session)
- context = await metadata_service.get_delete_context(
- search_space_id, user_id, issue_ref
- )
-
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
- logger.warning(f"Auth expired for delete context: {error_msg}")
- return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "linear",
- }
- if "not found" in error_msg.lower():
- logger.warning(f"Issue not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- else:
- logger.error(f"Failed to fetch delete context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- issue_id = context["issue"]["id"]
- issue_identifier = context["issue"].get("identifier", "")
- document_id = context["issue"]["document_id"]
- connector_id_from_context = context.get("workspace", {}).get("id")
-
- logger.info(
- f"Requesting approval for deleting Linear issue: '{issue_ref}' "
- f"(id={issue_id}, delete_from_kb={delete_from_kb})"
- )
- result = request_approval(
- action_type="linear_issue_deletion",
- tool_name="delete_linear_issue",
- params={
- "issue_id": issue_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Linear issue deletion rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_issue_id = result.params.get("issue_id", issue_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- logger.info(
- f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
- f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
- )
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if final_connector_id:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.LINEAR_CONNECTOR,
- )
+ async with async_session_maker() as db_session:
+ metadata_service = LinearToolMetadataService(db_session)
+ context = await metadata_service.get_delete_context(
+ search_space_id, user_id, issue_ref
)
- connector = result.scalars().first()
- if not connector:
- logger.error(
- f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
+
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ logger.warning(f"Auth expired for delete context: {error_msg}")
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "linear",
+ }
+ if "not found" in error_msg.lower():
+ logger.warning(f"Issue not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ else:
+ logger.error(f"Failed to fetch delete context: {error_msg}")
+ return {"status": "error", "message": error_msg}
+
+ issue_id = context["issue"]["id"]
+ issue_identifier = context["issue"].get("identifier", "")
+ document_id = context["issue"]["document_id"]
+ connector_id_from_context = context.get("workspace", {}).get("id")
+
+ logger.info(
+ f"Requesting approval for deleting Linear issue: '{issue_ref}' "
+ f"(id={issue_id}, delete_from_kb={delete_from_kb})"
+ )
+ result = request_approval(
+ action_type="linear_issue_deletion",
+ tool_name="delete_linear_issue",
+ params={
+ "issue_id": issue_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ logger.info("Linear issue deletion rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_issue_id = result.params.get("issue_id", issue_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
+
+ logger.info(
+ f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
+ f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
+ )
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if final_connector_id:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.LINEAR_CONNECTOR,
+ )
)
+ connector = result.scalars().first()
+ if not connector:
+ logger.error(
+ f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "Selected Linear connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = connector.id
+ logger.info(f"Validated Linear connector: id={actual_connector_id}")
+ else:
+ logger.error("No connector found for this issue")
return {
"status": "error",
- "message": "Selected Linear connector is invalid or has been disconnected.",
+ "message": "No connector found for this issue.",
}
- actual_connector_id = connector.id
- logger.info(f"Validated Linear connector: id={actual_connector_id}")
- else:
- logger.error("No connector found for this issue")
- return {
- "status": "error",
- "message": "No connector found for this issue.",
- }
- linear_client = LinearConnector(
- session=db_session, connector_id=actual_connector_id
- )
+ linear_client = LinearConnector(
+ session=db_session, connector_id=actual_connector_id
+ )
- result = await linear_client.archive_issue(issue_id=final_issue_id)
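+ # Linear archives rather than hard-deletes; the issue can typically be
+ # restored from Linear's archive if this was a mistake.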
+ result = await linear_client.archive_issue(issue_id=final_issue_id)
- logger.info(
- f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
- )
+ logger.info(
+ f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
+ )
- deleted_from_kb = False
- if (
- result.get("status") == "success"
- and final_delete_from_kb
- and document_id
- ):
- try:
- from app.db import Document
+ deleted_from_kb = False
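+ # Optionally drop the mirrored KB document; a failure here downgrades
+ # to a warning so the Linear archive still counts as success.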
+ if (
+ result.get("status") == "success"
+ and final_delete_from_kb
+ and document_id
+ ):
+ try:
+ from app.db import Document
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ result["warning"] = (
+ f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
)
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- result["warning"] = (
- f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
- )
- if result.get("status") == "success":
- result["deleted_from_kb"] = deleted_from_kb
- if issue_identifier:
- result["message"] = (
- f"Issue {issue_identifier} archived successfully."
- )
- if deleted_from_kb:
- result["message"] = (
- f"{result.get('message', '')} Also removed from the knowledge base."
- )
+ if result.get("status") == "success":
+ result["deleted_from_kb"] = deleted_from_kb
+ if issue_identifier:
+ result["message"] = (
+ f"Issue {issue_identifier} archived successfully."
+ )
+ if deleted_from_kb:
+ result["message"] = (
+ f"{result.get('message', '')} Also removed from the knowledge base."
+ )
- return result
+ return result
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py
index f35d0dddd..d610ce2b7 100644
--- a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
+from app.db import async_session_maker
from app.services.linear import LinearKBSyncService, LinearToolMetadataService
logger = logging.getLogger(__name__)
@@ -17,11 +18,17 @@ def create_update_linear_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
- """
- Factory function to create the update_linear_issue tool.
+ """Factory function to create the update_linear_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits.
Args:
- db_session: Database session for accessing the Linear connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@@ -29,6 +36,7 @@ def create_update_linear_issue_tool(
Returns:
Configured update_linear_issue tool
"""
+ del db_session # per-call session — see docstring
@tool
async def update_linear_issue(
@@ -86,7 +94,7 @@ def create_update_linear_issue_tool(
"""
logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Linear tool not properly configured - missing required parameters"
)
@@ -96,176 +104,177 @@ def create_update_linear_issue_tool(
}
try:
- metadata_service = LinearToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, issue_ref
- )
-
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
- logger.warning(f"Auth expired for update context: {error_msg}")
- return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "linear",
- }
- if "not found" in error_msg.lower():
- logger.warning(f"Issue not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- else:
- logger.error(f"Failed to fetch update context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- issue_id = context["issue"]["id"]
- document_id = context["issue"]["document_id"]
- connector_id_from_context = context.get("workspace", {}).get("id")
-
- team = context.get("team", {})
- new_state_id = _resolve_state(team, new_state_name)
- new_assignee_id = _resolve_assignee(team, new_assignee_email)
- new_label_ids = _resolve_labels(team, new_label_names)
-
- logger.info(
- f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
- )
- result = request_approval(
- action_type="linear_issue_update",
- tool_name="update_linear_issue",
- params={
- "issue_id": issue_id,
- "document_id": document_id,
- "new_title": new_title,
- "new_description": new_description,
- "new_state_id": new_state_id,
- "new_assignee_id": new_assignee_id,
- "new_priority": new_priority,
- "new_label_ids": new_label_ids,
- "connector_id": connector_id_from_context,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Linear issue update rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_issue_id = result.params.get("issue_id", issue_id)
- final_document_id = result.params.get("document_id", document_id)
- final_new_title = result.params.get("new_title", new_title)
- final_new_description = result.params.get(
- "new_description", new_description
- )
- final_new_state_id = result.params.get("new_state_id", new_state_id)
- final_new_assignee_id = result.params.get(
- "new_assignee_id", new_assignee_id
- )
- final_new_priority = result.params.get("new_priority", new_priority)
- final_new_label_ids: list[str] | None = result.params.get(
- "new_label_ids", new_label_ids
- )
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
-
- if not final_connector_id:
- logger.error("No connector found for this issue")
- return {
- "status": "error",
- "message": "No connector found for this issue.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.LINEAR_CONNECTOR,
+ async with async_session_maker() as db_session:
+ metadata_service = LinearToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, issue_ref
)
- )
- connector = result.scalars().first()
- if not connector:
- logger.error(
- f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
- )
- return {
- "status": "error",
- "message": "Selected Linear connector is invalid or has been disconnected.",
- }
- logger.info(f"Validated Linear connector: id={final_connector_id}")
- logger.info(
- f"Updating Linear issue with final params: issue_id={final_issue_id}"
- )
- linear_client = LinearConnector(
- session=db_session, connector_id=final_connector_id
- )
- updated_issue = await linear_client.update_issue(
- issue_id=final_issue_id,
- title=final_new_title,
- description=final_new_description,
- state_id=final_new_state_id,
- assignee_id=final_new_assignee_id,
- priority=final_new_priority,
- label_ids=final_new_label_ids,
- )
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ logger.warning(f"Auth expired for update context: {error_msg}")
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "linear",
+ }
+ if "not found" in error_msg.lower():
+ logger.warning(f"Issue not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ else:
+ logger.error(f"Failed to fetch update context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- if updated_issue.get("status") == "error":
- logger.error(
- f"Failed to update Linear issue: {updated_issue.get('message')}"
- )
- return {
- "status": "error",
- "message": updated_issue.get("message"),
- }
+ issue_id = context["issue"]["id"]
+ document_id = context["issue"]["document_id"]
+ connector_id_from_context = context.get("workspace", {}).get("id")
- logger.info(
- f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}"
- )
+ team = context.get("team", {})
+ new_state_id = _resolve_state(team, new_state_name)
+ new_assignee_id = _resolve_assignee(team, new_assignee_email)
+ new_label_ids = _resolve_labels(team, new_label_names)
- if final_document_id is not None:
logger.info(
- f"Updating knowledge base for document {final_document_id}..."
+ f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
)
- kb_service = LinearKBSyncService(db_session)
- kb_result = await kb_service.sync_after_update(
- document_id=final_document_id,
- issue_id=final_issue_id,
- user_id=user_id,
- search_space_id=search_space_id,
+ result = request_approval(
+ action_type="linear_issue_update",
+ tool_name="update_linear_issue",
+ params={
+ "issue_id": issue_id,
+ "document_id": document_id,
+ "new_title": new_title,
+ "new_description": new_description,
+ "new_state_id": new_state_id,
+ "new_assignee_id": new_assignee_id,
+ "new_priority": new_priority,
+ "new_label_ids": new_label_ids,
+ "connector_id": connector_id_from_context,
+ },
+ context=context,
)
- if kb_result["status"] == "success":
- logger.info(
- f"Knowledge base successfully updated for issue {final_issue_id}"
- )
- kb_message = " Your knowledge base has also been updated."
- elif kb_result["status"] == "not_indexed":
- kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
- else:
- logger.warning(
- f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
- )
- kb_message = " Your knowledge base will be updated in the next scheduled sync."
- else:
- kb_message = ""
- identifier = updated_issue.get("identifier")
- default_msg = f"Issue {identifier} updated successfully."
- return {
- "status": "success",
- "identifier": identifier,
- "url": updated_issue.get("url"),
- "message": f"{updated_issue.get('message', default_msg)}{kb_message}",
- }
+ if result.rejected:
+ logger.info("Linear issue update rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_issue_id = result.params.get("issue_id", issue_id)
+ final_document_id = result.params.get("document_id", document_id)
+ final_new_title = result.params.get("new_title", new_title)
+ final_new_description = result.params.get(
+ "new_description", new_description
+ )
+ final_new_state_id = result.params.get("new_state_id", new_state_id)
+ final_new_assignee_id = result.params.get(
+ "new_assignee_id", new_assignee_id
+ )
+ final_new_priority = result.params.get("new_priority", new_priority)
+ final_new_label_ids: list[str] | None = result.params.get(
+ "new_label_ids", new_label_ids
+ )
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+
+ if not final_connector_id:
+ logger.error("No connector found for this issue")
+ return {
+ "status": "error",
+ "message": "No connector found for this issue.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.LINEAR_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ logger.error(
+ f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "Selected Linear connector is invalid or has been disconnected.",
+ }
+ logger.info(f"Validated Linear connector: id={final_connector_id}")
+
+ logger.info(
+ f"Updating Linear issue with final params: issue_id={final_issue_id}"
+ )
+ linear_client = LinearConnector(
+ session=db_session, connector_id=final_connector_id
+ )
+ updated_issue = await linear_client.update_issue(
+ issue_id=final_issue_id,
+ title=final_new_title,
+ description=final_new_description,
+ state_id=final_new_state_id,
+ assignee_id=final_new_assignee_id,
+ priority=final_new_priority,
+ label_ids=final_new_label_ids,
+ )
+
+ if updated_issue.get("status") == "error":
+ logger.error(
+ f"Failed to update Linear issue: {updated_issue.get('message')}"
+ )
+ return {
+ "status": "error",
+ "message": updated_issue.get("message"),
+ }
+
+ logger.info(
+ f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}"
+ )
+
+ if final_document_id is not None:
+ logger.info(
+ f"Updating knowledge base for document {final_document_id}..."
+ )
+ kb_service = LinearKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_update(
+ document_id=final_document_id,
+ issue_id=final_issue_id,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ )
+ if kb_result["status"] == "success":
+ logger.info(
+ f"Knowledge base successfully updated for issue {final_issue_id}"
+ )
+ kb_message = " Your knowledge base has also been updated."
+ elif kb_result["status"] == "not_indexed":
+ kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
+ else:
+ logger.warning(
+ f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
+ )
+ kb_message = " Your knowledge base will be updated in the next scheduled sync."
+ else:
+ kb_message = ""
+
+ identifier = updated_issue.get("identifier")
+ default_msg = f"Issue {identifier} updated successfully."
+ return {
+ "status": "success",
+ "identifier": identifier,
+ "url": updated_issue.get("url"),
+ "message": f"{updated_issue.get('message', default_msg)}{kb_message}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
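
Every HITL tool in this patch follows the same contract around request_approval: return early on rejection (and tell the model not to retry), then re-read each parameter from result.params with the originally proposed value as the fallback, since the user may have edited any field in the approval UI, and only then re-validate the connector and act. A minimal sketch of the merge step, assuming a simplified ApprovalResult shape (the real type lives in app.agents.new_chat.tools.hitl):

# Minimal sketch of the approve-then-merge contract; ApprovalResult is a
# stand-in for the real type from app.agents.new_chat.tools.hitl.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ApprovalResult:
    rejected: bool = False
    params: dict[str, Any] = field(default_factory=dict)


def merge_approved_params(
    result: ApprovalResult, proposed: dict[str, Any]
) -> dict[str, Any] | None:
    """Return the effective params after review, or None when rejected."""
    if result.rejected:
        return None  # the user declined: callers must not retry the action
    # Any field may have been edited in the approval UI, so every value is
    # re-read from result.params, falling back to the original proposal.
    return {key: result.params.get(key, value) for key, value in proposed.items()}

This is why the hunks above never reuse their pre-approval locals directly: a value edited during review must win over what the model proposed.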
diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
index 0a24a988f..65c177d7a 100644
--- a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
@@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
@@ -17,6 +18,23 @@ def create_create_luma_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_luma_event tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Luma connector
+ user_id: User ID for finding the correct Luma connector
+
+ Returns:
+ Configured create_luma_event tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_luma_event(
name: str,
@@ -40,83 +58,86 @@ def create_create_luma_event_tool(
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
try:
- connector = await get_luma_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Luma connector found."}
+ async with async_session_maker() as db_session:
+ connector = await get_luma_connector(
+ db_session, search_space_id, user_id
+ )
+ if not connector:
+ return {"status": "error", "message": "No Luma connector found."}
- result = request_approval(
- action_type="luma_create_event",
- tool_name="create_luma_event",
- params={
- "name": name,
- "start_at": start_at,
- "end_at": end_at,
- "description": description,
- "timezone": timezone,
- },
- context={"connector_id": connector.id},
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Event was not created.",
- }
-
- final_name = result.params.get("name", name)
- final_start = result.params.get("start_at", start_at)
- final_end = result.params.get("end_at", end_at)
- final_desc = result.params.get("description", description)
- final_tz = result.params.get("timezone", timezone)
-
- api_key = get_api_key(connector)
- headers = luma_headers(api_key)
-
- body: dict[str, Any] = {
- "name": final_name,
- "start_at": final_start,
- "end_at": final_end,
- "timezone": final_tz,
- }
- if final_desc:
- body["description_md"] = final_desc
-
- async with httpx.AsyncClient(timeout=20.0) as client:
- resp = await client.post(
- f"{LUMA_API}/event/create",
- headers=headers,
- json=body,
+ result = request_approval(
+ action_type="luma_create_event",
+ tool_name="create_luma_event",
+ params={
+ "name": name,
+ "start_at": start_at,
+ "end_at": end_at,
+ "description": description,
+ "timezone": timezone,
+ },
+ context={"connector_id": connector.id},
)
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Luma API key is invalid.",
- "connector_type": "luma",
- }
- if resp.status_code == 403:
- return {
- "status": "error",
- "message": "Luma Plus subscription required to create events via API.",
- }
- if resp.status_code not in (200, 201):
- return {
- "status": "error",
- "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Event was not created.",
+ }
- data = resp.json()
- event_id = data.get("api_id") or data.get("event", {}).get("api_id")
+ final_name = result.params.get("name", name)
+ final_start = result.params.get("start_at", start_at)
+ final_end = result.params.get("end_at", end_at)
+ final_desc = result.params.get("description", description)
+ final_tz = result.params.get("timezone", timezone)
- return {
- "status": "success",
- "event_id": event_id,
- "message": f"Event '{final_name}' created on Luma.",
- }
+ api_key = get_api_key(connector)
+ headers = luma_headers(api_key)
+
+ body: dict[str, Any] = {
+ "name": final_name,
+ "start_at": final_start,
+ "end_at": final_end,
+ "timezone": final_tz,
+ }
+ if final_desc:
+ body["description_md"] = final_desc
+
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ resp = await client.post(
+ f"{LUMA_API}/event/create",
+ headers=headers,
+ json=body,
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Luma API key is invalid.",
+ "connector_type": "luma",
+ }
+ if resp.status_code == 403:
+ return {
+ "status": "error",
+ "message": "Luma Plus subscription required to create events via API.",
+ }
+ if resp.status_code not in (200, 201):
+ return {
+ "status": "error",
+ "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}",
+ }
+
+ data = resp.json()
+ event_id = data.get("api_id") or data.get("event", {}).get("api_id")
+
+ return {
+ "status": "success",
+ "event_id": event_id,
+ "message": f"Event '{final_name}' created on Luma.",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
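
The docstring above captures the core of this patch: the factory deliberately discards the injected db_session, and the tool opens a fresh session on every invocation. A compressed sketch of the pattern, assuming async_session_maker is SQLAlchemy's async_sessionmaker bound to the app engine (exported from app.db, as in this patch) and do_work stands in for the real tool body:

# Sketch of the per-call-session factory pattern described in the docstring.
from langchain_core.tools import tool

from app.db import async_session_maker  # assumed: async_sessionmaker(engine)


async def do_work(session, search_space_id, user_id) -> dict:
    # Placeholder for the real tool body (connector lookup, API calls, ...).
    return {"status": "success"}


def create_example_tool(db_session=None, search_space_id=None, user_id=None):
    del db_session  # never capture a per-request session in the closure

    @tool
    async def example_tool() -> dict:
        """Example tool that opens its own short-lived session per call."""
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Tool not properly configured."}
        # A fresh session per invocation stays valid even when the compiled
        # agent (and this closure) is served from cache on a later request.
        async with async_session_maker() as session:
            return await do_work(session, search_space_id, user_id)

    return example_tool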
diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
index aec5ad220..6885c2049 100644
--- a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
+++ b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_list_luma_events_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the list_luma_events tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Luma connector
+ user_id: User ID for finding the correct Luma connector
+
+ Returns:
+ Configured list_luma_events tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def list_luma_events(
max_results: int = 25,
@@ -28,77 +47,80 @@ def create_list_luma_events_tool(
Dictionary with status and a list of events including
event_id, name, start_at, end_at, location, url.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
max_results = min(max_results, 50)
try:
- connector = await get_luma_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Luma connector found."}
+ async with async_session_maker() as db_session:
+ connector = await get_luma_connector(
+ db_session, search_space_id, user_id
+ )
+ if not connector:
+ return {"status": "error", "message": "No Luma connector found."}
- api_key = get_api_key(connector)
- headers = luma_headers(api_key)
+ api_key = get_api_key(connector)
+ headers = luma_headers(api_key)
- all_entries: list[dict] = []
- cursor = None
+ all_entries: list[dict] = []
+ cursor = None
- async with httpx.AsyncClient(timeout=20.0) as client:
- while len(all_entries) < max_results:
- params: dict[str, Any] = {
- "limit": min(100, max_results - len(all_entries))
- }
- if cursor:
- params["cursor"] = cursor
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ while len(all_entries) < max_results:
+ params: dict[str, Any] = {
+ "limit": min(100, max_results - len(all_entries))
+ }
+ if cursor:
+ params["cursor"] = cursor
- resp = await client.get(
- f"{LUMA_API}/calendar/list-events",
- headers=headers,
- params=params,
+ resp = await client.get(
+ f"{LUMA_API}/calendar/list-events",
+ headers=headers,
+ params=params,
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Luma API key is invalid.",
+ "connector_type": "luma",
+ }
+ if resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Luma API error: {resp.status_code}",
+ }
+
+ data = resp.json()
+ entries = data.get("entries", [])
+ if not entries:
+ break
+ all_entries.extend(entries)
+
+ next_cursor = data.get("next_cursor")
+ if not next_cursor:
+ break
+ cursor = next_cursor
+
+ events = []
+ for entry in all_entries[:max_results]:
+ ev = entry.get("event", {})
+ geo = ev.get("geo_info", {})
+ events.append(
+ {
+ "event_id": entry.get("api_id"),
+ "name": ev.get("name", "Untitled"),
+ "start_at": ev.get("start_at", ""),
+ "end_at": ev.get("end_at", ""),
+ "timezone": ev.get("timezone", ""),
+ "location": geo.get("name", ""),
+ "url": ev.get("url", ""),
+ "visibility": ev.get("visibility", ""),
+ }
)
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Luma API key is invalid.",
- "connector_type": "luma",
- }
- if resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Luma API error: {resp.status_code}",
- }
-
- data = resp.json()
- entries = data.get("entries", [])
- if not entries:
- break
- all_entries.extend(entries)
-
- next_cursor = data.get("next_cursor")
- if not next_cursor:
- break
- cursor = next_cursor
-
- events = []
- for entry in all_entries[:max_results]:
- ev = entry.get("event", {})
- geo = ev.get("geo_info", {})
- events.append(
- {
- "event_id": entry.get("api_id"),
- "name": ev.get("name", "Untitled"),
- "start_at": ev.get("start_at", ""),
- "end_at": ev.get("end_at", ""),
- "timezone": ev.get("timezone", ""),
- "location": geo.get("name", ""),
- "url": ev.get("url", ""),
- "visibility": ev.get("visibility", ""),
- }
- )
-
- return {"status": "success", "events": events, "total": len(events)}
+ return {"status": "success", "events": events, "total": len(events)}
except Exception as e:
from langgraph.errors import GraphInterrupt
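
list_luma_events pages through the calendar endpoint with cursor pagination, capping both the per-page limit (100) and the overall max_results (50). The same loop in isolation; the endpoint path and the entries/next_cursor response fields mirror the patch and are assumptions outside it, and unlike the tool this sketch raises on HTTP errors instead of returning structured status dicts:

# Standalone sketch of the cursor-pagination loop in list_luma_events.
import httpx


async def fetch_all_events(
    base_url: str, headers: dict, max_results: int = 25
) -> list[dict]:
    max_results = min(max_results, 50)  # hard cap, as in the tool
    entries: list[dict] = []
    cursor: str | None = None
    async with httpx.AsyncClient(timeout=20.0) as client:
        while len(entries) < max_results:
            params: dict = {"limit": min(100, max_results - len(entries))}
            if cursor:
                params["cursor"] = cursor
            resp = await client.get(
                f"{base_url}/calendar/list-events", headers=headers, params=params
            )
            resp.raise_for_status()  # the real tool maps codes to status dicts
            data = resp.json()
            page = data.get("entries", [])
            if not page:
                break  # empty page: nothing more to fetch
            entries.extend(page)
            cursor = data.get("next_cursor")
            if not cursor:
                break  # no cursor: last page reached
    return entries[:max_results]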
diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py
index b37a9d617..a8484e9c0 100644
--- a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_read_luma_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the read_luma_event tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Luma connector
+ user_id: User ID for finding the correct Luma connector
+
+ Returns:
+ Configured read_luma_event tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def read_luma_event(event_id: str) -> dict[str, Any]:
"""Read detailed information about a specific Luma event.
@@ -26,60 +45,63 @@ def create_read_luma_event_tool(
Dictionary with status and full event details including
description, attendees count, meeting URL.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
try:
- connector = await get_luma_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Luma connector found."}
-
- api_key = get_api_key(connector)
- headers = luma_headers(api_key)
-
- async with httpx.AsyncClient(timeout=15.0) as client:
- resp = await client.get(
- f"{LUMA_API}/events/{event_id}",
- headers=headers,
+ async with async_session_maker() as db_session:
+ connector = await get_luma_connector(
+ db_session, search_space_id, user_id
)
+ if not connector:
+ return {"status": "error", "message": "No Luma connector found."}
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Luma API key is invalid.",
- "connector_type": "luma",
- }
- if resp.status_code == 404:
- return {
- "status": "not_found",
- "message": f"Event '{event_id}' not found.",
- }
- if resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Luma API error: {resp.status_code}",
+ api_key = get_api_key(connector)
+ headers = luma_headers(api_key)
+
+ async with httpx.AsyncClient(timeout=15.0) as client:
+ resp = await client.get(
+ f"{LUMA_API}/events/{event_id}",
+ headers=headers,
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Luma API key is invalid.",
+ "connector_type": "luma",
+ }
+ if resp.status_code == 404:
+ return {
+ "status": "not_found",
+ "message": f"Event '{event_id}' not found.",
+ }
+ if resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Luma API error: {resp.status_code}",
+ }
+
+ data = resp.json()
+ ev = data.get("event", data)
+ geo = ev.get("geo_info", {})
+
+ event_detail = {
+ "event_id": event_id,
+ "name": ev.get("name", ""),
+ "description": ev.get("description", ""),
+ "start_at": ev.get("start_at", ""),
+ "end_at": ev.get("end_at", ""),
+ "timezone": ev.get("timezone", ""),
+ "location_name": geo.get("name", ""),
+ "address": geo.get("address", ""),
+ "url": ev.get("url", ""),
+ "meeting_url": ev.get("meeting_url", ""),
+ "visibility": ev.get("visibility", ""),
+ "cover_url": ev.get("cover_url", ""),
}
- data = resp.json()
- ev = data.get("event", data)
- geo = ev.get("geo_info", {})
-
- event_detail = {
- "event_id": event_id,
- "name": ev.get("name", ""),
- "description": ev.get("description", ""),
- "start_at": ev.get("start_at", ""),
- "end_at": ev.get("end_at", ""),
- "timezone": ev.get("timezone", ""),
- "location_name": geo.get("name", ""),
- "address": geo.get("address", ""),
- "url": ev.get("url", ""),
- "meeting_url": ev.get("meeting_url", ""),
- "visibility": ev.get("visibility", ""),
- "cover_url": ev.get("cover_url", ""),
- }
-
- return {"status": "success", "event": event_detail}
+ return {"status": "success", "event": event_detail}
except Exception as e:
from langgraph.errors import GraphInterrupt
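
All three Luma tools translate HTTP status codes into the structured status values the agent dispatches on: 401 becomes auth_error with a connector_type hint so the UI can offer re-authentication, 404 becomes not_found, and any other non-200 becomes a generic error. Factored out as an illustrative helper (not part of the patch, which inlines these checks):

# Illustrative status-code translation mirroring the inlined checks above.
import httpx


def classify_luma_response(resp: httpx.Response) -> dict | None:
    """Return a structured error dict, or None when the response is usable."""
    if resp.status_code == 401:
        # auth_error carries connector_type so the UI can prompt re-auth
        return {
            "status": "auth_error",
            "message": "Luma API key is invalid.",
            "connector_type": "luma",
        }
    if resp.status_code == 404:
        return {"status": "not_found", "message": "Event not found."}
    if resp.status_code != 200:
        return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
    return None  # 200: the caller proceeds with resp.json()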
diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py
index 6efffe960..6ec95e9f0 100644
--- a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
+from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__)
@@ -20,8 +21,17 @@ def create_create_notion_page_tool(
"""
Factory function to create the create_notion_page tool.
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits. Per-call sessions also
+ avoid holding the request's outer transaction open across long-running
+ Notion API calls.
+
Args:
- db_session: Database session for accessing Notion connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@@ -29,6 +39,7 @@ def create_create_notion_page_tool(
Returns:
Configured create_notion_page tool
"""
+ del db_session # per-call session — see docstring
@tool
async def create_notion_page(
@@ -67,7 +78,7 @@ def create_create_notion_page_tool(
"""
logger.info(f"create_notion_page called: title='{title}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Notion tool not properly configured - missing required parameters"
)
@@ -77,154 +88,157 @@ def create_create_notion_page_tool(
}
try:
- metadata_service = NotionToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
-
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {
- "status": "error",
- "message": context["error"],
- }
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- logger.warning("All Notion accounts have expired authentication")
- return {
- "status": "auth_error",
- "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "notion",
- }
-
- logger.info(f"Requesting approval for creating Notion page: '{title}'")
- result = request_approval(
- action_type="notion_page_creation",
- tool_name="create_notion_page",
- params={
- "title": title,
- "content": content,
- "parent_page_id": None,
- "connector_id": connector_id,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Notion page creation rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_title = result.params.get("title", title)
- final_content = result.params.get("content", content)
- final_parent_page_id = result.params.get("parent_page_id")
- final_connector_id = result.params.get("connector_id", connector_id)
-
- if not final_title or not final_title.strip():
- logger.error("Title is empty or contains only whitespace")
- return {
- "status": "error",
- "message": "Page title cannot be empty. Please provide a valid title.",
- }
-
- logger.info(
- f"Creating Notion page with final params: title='{final_title}'"
- )
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- actual_connector_id = final_connector_id
- if actual_connector_id is None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.NOTION_CONNECTOR,
- )
+ async with async_session_maker() as db_session:
+ metadata_service = NotionToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
)
- connector = result.scalars().first()
- if not connector:
- logger.warning(
- f"No Notion connector found for search_space_id={search_space_id}"
- )
- return {
- "status": "error",
- "message": "No Notion connector found. Please connect Notion in your workspace settings.",
- }
-
- actual_connector_id = connector.id
- logger.info(f"Found Notion connector: id={actual_connector_id}")
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == actual_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.NOTION_CONNECTOR,
- )
- )
- connector = result.scalars().first()
-
- if not connector:
+ if "error" in context:
logger.error(
- f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}"
+ f"Failed to fetch creation context: {context['error']}"
)
return {
"status": "error",
- "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
+ "message": context["error"],
}
- logger.info(f"Validated Notion connector: id={actual_connector_id}")
- notion_connector = NotionHistoryConnector(
- session=db_session,
- connector_id=actual_connector_id,
- )
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ logger.warning("All Notion accounts have expired authentication")
+ return {
+ "status": "auth_error",
+ "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "notion",
+ }
- result = await notion_connector.create_page(
- title=final_title,
- content=final_content,
- parent_page_id=final_parent_page_id,
- )
- logger.info(
- f"create_page result: {result.get('status')} - {result.get('message', '')}"
- )
+ logger.info(f"Requesting approval for creating Notion page: '{title}'")
+ result = request_approval(
+ action_type="notion_page_creation",
+ tool_name="create_notion_page",
+ params={
+ "title": title,
+ "content": content,
+ "parent_page_id": None,
+ "connector_id": connector_id,
+ },
+ context=context,
+ )
- if result.get("status") == "success":
- kb_message_suffix = ""
- try:
- from app.services.notion import NotionKBSyncService
+ if result.rejected:
+ logger.info("Notion page creation rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
- kb_service = NotionKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- page_id=result.get("page_id"),
- page_title=result.get("title", final_title),
- page_url=result.get("url"),
- content=final_content,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
- if kb_result["status"] == "success":
- kb_message_suffix = (
- " Your knowledge base has also been updated."
+ final_title = result.params.get("title", title)
+ final_content = result.params.get("content", content)
+ final_parent_page_id = result.params.get("parent_page_id")
+ final_connector_id = result.params.get("connector_id", connector_id)
+
+ if not final_title or not final_title.strip():
+ logger.error("Title is empty or contains only whitespace")
+ return {
+ "status": "error",
+ "message": "Page title cannot be empty. Please provide a valid title.",
+ }
+
+ logger.info(
+ f"Creating Notion page with final params: title='{final_title}'"
+ )
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ actual_connector_id = final_connector_id
+ if actual_connector_id is None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.NOTION_CONNECTOR,
)
- else:
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ logger.warning(
+ f"No Notion connector found for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "No Notion connector found. Please connect Notion in your workspace settings.",
+ }
+
+ actual_connector_id = connector.id
+ logger.info(f"Found Notion connector: id={actual_connector_id}")
+ else:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == actual_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.NOTION_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ logger.error(
+ f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
+ }
+ logger.info(f"Validated Notion connector: id={actual_connector_id}")
+
+ notion_connector = NotionHistoryConnector(
+ session=db_session,
+ connector_id=actual_connector_id,
+ )
+
+ result = await notion_connector.create_page(
+ title=final_title,
+ content=final_content,
+ parent_page_id=final_parent_page_id,
+ )
+ logger.info(
+ f"create_page result: {result.get('status')} - {result.get('message', '')}"
+ )
+
+ if result.get("status") == "success":
+ kb_message_suffix = ""
+ try:
+ from app.services.notion import NotionKBSyncService
+
+ kb_service = NotionKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ page_id=result.get("page_id"),
+ page_title=result.get("title", final_title),
+ page_url=result.get("url"),
+ content=final_content,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
- result["message"] = result.get("message", "") + kb_message_suffix
+ result["message"] = result.get("message", "") + kb_message_suffix
- return result
+ return result
except Exception as e:
from langgraph.errors import GraphInterrupt
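
create_notion_page resolves its connector in two modes: when the approval result carries no connector_id it takes the workspace's first Notion connector, and when one is supplied it re-validates the id against search space, user, and connector type so a stale or cross-tenant id is rejected. A condensed sketch that collapses both branches into a single query (session handling and return typing are assumed):

# Condensed sketch of the connector-resolution branch in create_notion_page.
from sqlalchemy.future import select

from app.db import SearchSourceConnector, SearchSourceConnectorType


async def resolve_notion_connector(
    db_session, search_space_id: int, user_id: str, connector_id: int | None
):
    """Return a validated Notion connector for this workspace, or None."""
    filters = [
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.connector_type
        == SearchSourceConnectorType.NOTION_CONNECTOR,
    ]
    if connector_id is not None:
        # Re-validate the user-selected id instead of trusting it blindly.
        filters.append(SearchSourceConnector.id == connector_id)
    result = await db_session.execute(select(SearchSourceConnector).filter(*filters))
    # None here maps to the structured "invalid or disconnected" error above.
    return result.scalars().first()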
diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py
index 07f7583d2..7b85da4c2 100644
--- a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
+from app.db import async_session_maker
from app.services.notion.tool_metadata_service import NotionToolMetadataService
logger = logging.getLogger(__name__)
@@ -20,8 +21,14 @@ def create_delete_notion_page_tool(
"""
Factory function to create the delete_notion_page tool.
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
Args:
- db_session: Database session for accessing Notion connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector
user_id: User ID for finding the correct Notion connector
connector_id: Optional specific connector ID (if known)
@@ -29,6 +36,7 @@ def create_delete_notion_page_tool(
Returns:
Configured delete_notion_page tool
"""
+ del db_session # per-call session — see docstring
@tool
async def delete_notion_page(
@@ -63,7 +71,7 @@ def create_delete_notion_page_tool(
f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Notion tool not properly configured - missing required parameters"
)
@@ -73,164 +81,167 @@ def create_delete_notion_page_tool(
}
try:
- # Get page context (page_id, account, title) from indexed data
- metadata_service = NotionToolMetadataService(db_session)
- context = await metadata_service.get_delete_context(
- search_space_id, user_id, page_title
- )
-
- if "error" in context:
- error_msg = context["error"]
- # Check if it's a "not found" error (softer handling for LLM)
- if "not found" in error_msg.lower():
- logger.warning(f"Page not found: {error_msg}")
- return {
- "status": "not_found",
- "message": error_msg,
- }
- else:
- logger.error(f"Failed to fetch delete context: {error_msg}")
- return {
- "status": "error",
- "message": error_msg,
- }
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Notion account %s has expired authentication",
- account.get("id"),
+ async with async_session_maker() as db_session:
+ # Get page context (page_id, account, title) from indexed data
+ metadata_service = NotionToolMetadataService(db_session)
+ context = await metadata_service.get_delete_context(
+ search_space_id, user_id, page_title
)
- return {
- "status": "auth_error",
- "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
- }
- page_id = context.get("page_id")
- connector_id_from_context = account.get("id")
- document_id = context.get("document_id")
-
- logger.info(
- f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
- )
-
- result = request_approval(
- action_type="notion_page_deletion",
- tool_name="delete_notion_page",
- params={
- "page_id": page_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Notion page deletion rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_page_id = result.params.get("page_id", page_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- logger.info(
- f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
- )
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- # Validate the connector
- if final_connector_id:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.NOTION_CONNECTOR,
- )
- )
- connector = result.scalars().first()
-
- if not connector:
- logger.error(
- f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
- )
- return {
- "status": "error",
- "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
- }
- actual_connector_id = connector.id
- logger.info(f"Validated Notion connector: id={actual_connector_id}")
- else:
- logger.error("No connector found for this page")
- return {
- "status": "error",
- "message": "No connector found for this page.",
- }
-
- # Create connector instance
- notion_connector = NotionHistoryConnector(
- session=db_session,
- connector_id=actual_connector_id,
- )
-
- # Delete the page from Notion
- result = await notion_connector.delete_page(page_id=final_page_id)
- logger.info(
- f"delete_page result: {result.get('status')} - {result.get('message', '')}"
- )
-
- # If deletion was successful and user wants to delete from KB
- deleted_from_kb = False
- if (
- result.get("status") == "success"
- and final_delete_from_kb
- and document_id
- ):
- try:
- from sqlalchemy.future import select
-
- from app.db import Document
-
- # Get the document
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- document = doc_result.scalars().first()
-
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
+ if "error" in context:
+ error_msg = context["error"]
+ # Check if it's a "not found" error (softer handling for LLM)
+ if "not found" in error_msg.lower():
+ logger.warning(f"Page not found: {error_msg}")
+ return {
+ "status": "not_found",
+ "message": error_msg,
+ }
else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- result["warning"] = (
- f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
- )
+ logger.error(f"Failed to fetch delete context: {error_msg}")
+ return {
+ "status": "error",
+ "message": error_msg,
+ }
- # Update result with KB deletion status
- if result.get("status") == "success":
- result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- result["message"] = (
- f"{result.get('message', '')} (also removed from knowledge base)"
+ account = context.get("account", {})
+ if account.get("auth_expired"):
+ logger.warning(
+ "Notion account %s has expired authentication",
+ account.get("id"),
)
+ return {
+ "status": "auth_error",
+ "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
+ }
- return result
+ page_id = context.get("page_id")
+ connector_id_from_context = account.get("id")
+ document_id = context.get("document_id")
+
+ logger.info(
+ f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
+ )
+
+ result = request_approval(
+ action_type="notion_page_deletion",
+ tool_name="delete_notion_page",
+ params={
+ "page_id": page_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ logger.info("Notion page deletion rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_page_id = result.params.get("page_id", page_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
+
+ logger.info(
+ f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
+ )
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ # Validate the connector
+ if final_connector_id:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.NOTION_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ logger.error(
+ f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
+ }
+ actual_connector_id = connector.id
+ logger.info(f"Validated Notion connector: id={actual_connector_id}")
+ else:
+ logger.error("No connector found for this page")
+ return {
+ "status": "error",
+ "message": "No connector found for this page.",
+ }
+
+ # Create connector instance
+ notion_connector = NotionHistoryConnector(
+ session=db_session,
+ connector_id=actual_connector_id,
+ )
+
+ # Delete the page from Notion
+ result = await notion_connector.delete_page(page_id=final_page_id)
+ logger.info(
+ f"delete_page result: {result.get('status')} - {result.get('message', '')}"
+ )
+
+ # If deletion was successful and user wants to delete from KB
+ deleted_from_kb = False
+ if (
+ result.get("status") == "success"
+ and final_delete_from_kb
+ and document_id
+ ):
+ try:
+ from sqlalchemy.future import select
+
+ from app.db import Document
+
+ # Get the document
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ result["warning"] = (
+ f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
+ )
+
+ # Update result with KB deletion status
+ if result.get("status") == "success":
+ result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ result["message"] = (
+ f"{result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return result
except Exception as e:
from langgraph.errors import GraphInterrupt
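
delete_notion_page treats knowledge-base cleanup as best effort: it runs only after the Notion deletion succeeds, and a KB failure is demoted to a warning on the result plus an explicit rollback, so the tool call still reports the successful Notion deletion. The same logic in isolation (Document is the model from app.db; result is the dict returned by the Notion call):

# Sketch of the best-effort KB cleanup in delete_notion_page.
from sqlalchemy.future import select

from app.db import Document


async def delete_kb_document(db_session, document_id: int, result: dict) -> bool:
    try:
        doc_result = await db_session.execute(
            select(Document).filter(Document.id == document_id)
        )
        document = doc_result.scalars().first()
        if document is None:
            return False  # not indexed; nothing to remove
        await db_session.delete(document)
        await db_session.commit()
        return True
    except Exception as exc:
        # Roll back so the session stays usable, and surface a warning
        # instead of failing: the Notion page itself was already deleted.
        await db_session.rollback()
        result["warning"] = (
            f"Page deleted from Notion, but failed to remove from knowledge base: {exc!s}"
        )
        return False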
diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py
index 85c08177c..df757476a 100644
--- a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
+from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__)
@@ -20,8 +21,14 @@ def create_update_notion_page_tool(
"""
Factory function to create the update_notion_page tool.
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache (see
+ ``create_create_notion_page_tool`` for the full rationale).
+
Args:
- db_session: Database session for accessing Notion connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@@ -29,6 +36,7 @@ def create_update_notion_page_tool(
Returns:
Configured update_notion_page tool
"""
+ del db_session # per-call session — see docstring
@tool
async def update_notion_page(
@@ -71,7 +79,7 @@ def create_update_notion_page_tool(
f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Notion tool not properly configured - missing required parameters"
)
@@ -88,152 +96,155 @@ def create_update_notion_page_tool(
}
try:
- metadata_service = NotionToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, page_title
- )
-
- if "error" in context:
- error_msg = context["error"]
- # Check if it's a "not found" error (softer handling for LLM)
- if "not found" in error_msg.lower():
- logger.warning(f"Page not found: {error_msg}")
- return {
- "status": "not_found",
- "message": error_msg,
- }
- else:
- logger.error(f"Failed to fetch update context: {error_msg}")
- return {
- "status": "error",
- "message": error_msg,
- }
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Notion account %s has expired authentication",
- account.get("id"),
- )
- return {
- "status": "auth_error",
- "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
- }
-
- page_id = context.get("page_id")
- document_id = context.get("document_id")
- connector_id_from_context = context.get("account", {}).get("id")
-
- logger.info(
- f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
- )
- result = request_approval(
- action_type="notion_page_update",
- tool_name="update_notion_page",
- params={
- "page_id": page_id,
- "content": content,
- "connector_id": connector_id_from_context,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Notion page update rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_page_id = result.params.get("page_id", page_id)
- final_content = result.params.get("content", content)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
-
- logger.info(
- f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
- )
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if final_connector_id:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.NOTION_CONNECTOR,
- )
- )
- connector = result.scalars().first()
-
- if not connector:
- logger.error(
- f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
- )
- return {
- "status": "error",
- "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
- }
- actual_connector_id = connector.id
- logger.info(f"Validated Notion connector: id={actual_connector_id}")
- else:
- logger.error("No connector found for this page")
- return {
- "status": "error",
- "message": "No connector found for this page.",
- }
-
- notion_connector = NotionHistoryConnector(
- session=db_session,
- connector_id=actual_connector_id,
- )
-
- result = await notion_connector.update_page(
- page_id=final_page_id,
- content=final_content,
- )
- logger.info(
- f"update_page result: {result.get('status')} - {result.get('message', '')}"
- )
-
- if result.get("status") == "success" and document_id is not None:
- from app.services.notion import NotionKBSyncService
-
- logger.info(f"Updating knowledge base for document {document_id}...")
- kb_service = NotionKBSyncService(db_session)
- kb_result = await kb_service.sync_after_update(
- document_id=document_id,
- appended_content=final_content,
- user_id=user_id,
- search_space_id=search_space_id,
- appended_block_ids=result.get("appended_block_ids"),
+ async with async_session_maker() as db_session:
+ metadata_service = NotionToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, page_title
)
- if kb_result["status"] == "success":
- result["message"] = (
- f"{result['message']}. Your knowledge base has also been updated."
- )
- logger.info(
- f"Knowledge base successfully updated for page {final_page_id}"
- )
- elif kb_result["status"] == "not_indexed":
- result["message"] = (
- f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync."
- )
- else:
- result["message"] = (
- f"{result['message']}. Your knowledge base will be updated in the next scheduled sync."
- )
+ if "error" in context:
+ error_msg = context["error"]
+ # Check if it's a "not found" error (softer handling for LLM)
+ if "not found" in error_msg.lower():
+ logger.warning(f"Page not found: {error_msg}")
+ return {
+ "status": "not_found",
+ "message": error_msg,
+ }
+ else:
+ logger.error(f"Failed to fetch update context: {error_msg}")
+ return {
+ "status": "error",
+ "message": error_msg,
+ }
+
+ account = context.get("account", {})
+ if account.get("auth_expired"):
logger.warning(
- f"KB update failed for page {final_page_id}: {kb_result['message']}"
+ "Notion account %s has expired authentication",
+ account.get("id"),
+ )
+ return {
+ "status": "auth_error",
+ "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
+ }
+
+ page_id = context.get("page_id")
+ document_id = context.get("document_id")
+ connector_id_from_context = context.get("account", {}).get("id")
+
+ logger.info(
+ f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
+ )
+ result = request_approval(
+ action_type="notion_page_update",
+ tool_name="update_notion_page",
+ params={
+ "page_id": page_id,
+ "content": content,
+ "connector_id": connector_id_from_context,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ logger.info("Notion page update rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_page_id = result.params.get("page_id", page_id)
+ final_content = result.params.get("content", content)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+
+ logger.info(
+ f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
+ )
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if final_connector_id:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.NOTION_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ logger.error(
+ f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
+ }
+ actual_connector_id = connector.id
+ logger.info(f"Validated Notion connector: id={actual_connector_id}")
+ else:
+ logger.error("No connector found for this page")
+ return {
+ "status": "error",
+ "message": "No connector found for this page.",
+ }
+
+ notion_connector = NotionHistoryConnector(
+ session=db_session,
+ connector_id=actual_connector_id,
+ )
+
+ result = await notion_connector.update_page(
+ page_id=final_page_id,
+ content=final_content,
+ )
+ logger.info(
+ f"update_page result: {result.get('status')} - {result.get('message', '')}"
+ )
+
+ if result.get("status") == "success" and document_id is not None:
+ from app.services.notion import NotionKBSyncService
+
+ logger.info(
+ f"Updating knowledge base for document {document_id}..."
+ )
+ kb_service = NotionKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_update(
+ document_id=document_id,
+ appended_content=final_content,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ appended_block_ids=result.get("appended_block_ids"),
)
- return result
+ if kb_result["status"] == "success":
+ result["message"] = (
+ f"{result['message']}. Your knowledge base has also been updated."
+ )
+ logger.info(
+ f"Knowledge base successfully updated for page {final_page_id}"
+ )
+ elif kb_result["status"] == "not_indexed":
+ result["message"] = (
+ f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync."
+ )
+ else:
+ result["message"] = (
+ f"{result['message']}. Your knowledge base will be updated in the next scheduled sync."
+ )
+ logger.warning(
+ f"KB update failed for page {final_page_id}: {kb_result['message']}"
+ )
+
+ return result
except Exception as e:
from langgraph.errors import GraphInterrupt
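
update_notion_page folds the KB sync outcome into the user-facing message: success appends a confirmation, not_indexed promises the next scheduled sync, and any other status degrades to the same promise while a warning is logged. As a pure function over the status strings used in the patch (wording copied from it):

# Sketch of the KB-sync message suffixes used by the update tools.
def kb_sync_suffix(kb_status: str, subject: str = "page") -> str:
    if kb_status == "success":
        return ". Your knowledge base has also been updated."
    if kb_status == "not_indexed":
        return (
            f". This {subject} will be added to your knowledge base "
            "in the next scheduled sync."
        )
    # Any other status: the scheduled sync will reconcile eventually.
    return ". Your knowledge base will be updated in the next scheduled sync."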
diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py
index 21272e01d..5f199a41b 100644
--- a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py
@@ -10,7 +10,7 @@ from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.onedrive.client import OneDriveClient
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@@ -48,6 +48,23 @@ def create_create_onedrive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_onedrive_file tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the OneDrive connector
+ user_id: User ID for finding the correct OneDrive connector
+
+ Returns:
+ Configured create_onedrive_file tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_onedrive_file(
name: str,
@@ -70,173 +87,178 @@ def create_create_onedrive_file_tool(
"""
logger.info(f"create_onedrive_file called: name='{name}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "OneDrive tool not properly configured.",
}
try:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
- )
- )
- connectors = result.scalars().all()
-
- if not connectors:
- return {
- "status": "error",
- "message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
- }
-
- accounts = []
- for c in connectors:
- cfg = c.config or {}
- accounts.append(
- {
- "id": c.id,
- "name": c.name,
- "user_email": cfg.get("user_email"),
- "auth_expired": cfg.get("auth_expired", False),
- }
- )
-
- if all(a.get("auth_expired") for a in accounts):
- return {
- "status": "auth_error",
- "message": "All connected OneDrive accounts need re-authentication.",
- "connector_type": "onedrive",
- }
-
- parent_folders: dict[int, list[dict[str, str]]] = {}
- for acc in accounts:
- cid = acc["id"]
- if acc.get("auth_expired"):
- parent_folders[cid] = []
- continue
- try:
- client = OneDriveClient(session=db_session, connector_id=cid)
- items, err = await client.list_children("root")
- if err:
- logger.warning(
- "Failed to list folders for connector %s: %s", cid, err
- )
- parent_folders[cid] = []
- else:
- parent_folders[cid] = [
- {"folder_id": item["id"], "name": item["name"]}
- for item in items
- if item.get("folder") is not None
- and item.get("id")
- and item.get("name")
- ]
- except Exception:
- logger.warning(
- "Error fetching folders for connector %s", cid, exc_info=True
- )
- parent_folders[cid] = []
-
- context: dict[str, Any] = {
- "accounts": accounts,
- "parent_folders": parent_folders,
- }
-
- result = request_approval(
- action_type="onedrive_file_creation",
- tool_name="create_onedrive_file",
- params={
- "name": name,
- "content": content,
- "connector_id": None,
- "parent_folder_id": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_name = result.params.get("name", name)
- final_content = result.params.get("content", content)
- final_connector_id = result.params.get("connector_id")
- final_parent_folder_id = result.params.get("parent_folder_id")
-
- if not final_name or not final_name.strip():
- return {"status": "error", "message": "File name cannot be empty."}
-
- final_name = _ensure_docx_extension(final_name)
-
- if final_connector_id is not None:
+ async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
)
)
- connector = result.scalars().first()
- else:
- connector = connectors[0]
+ connectors = result.scalars().all()
- if not connector:
- return {
- "status": "error",
- "message": "Selected OneDrive connector is invalid.",
+ if not connectors:
+ return {
+ "status": "error",
+ "message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
+ }
+
+ accounts = []
+ for c in connectors:
+ cfg = c.config or {}
+ accounts.append(
+ {
+ "id": c.id,
+ "name": c.name,
+ "user_email": cfg.get("user_email"),
+ "auth_expired": cfg.get("auth_expired", False),
+ }
+ )
+
+ if all(a.get("auth_expired") for a in accounts):
+ return {
+ "status": "auth_error",
+ "message": "All connected OneDrive accounts need re-authentication.",
+ "connector_type": "onedrive",
+ }
+
+ parent_folders: dict[int, list[dict[str, str]]] = {}
+ for acc in accounts:
+ cid = acc["id"]
+ if acc.get("auth_expired"):
+ parent_folders[cid] = []
+ continue
+ try:
+ client = OneDriveClient(session=db_session, connector_id=cid)
+ items, err = await client.list_children("root")
+ if err:
+ logger.warning(
+ "Failed to list folders for connector %s: %s", cid, err
+ )
+ parent_folders[cid] = []
+ else:
+ parent_folders[cid] = [
+ {"folder_id": item["id"], "name": item["name"]}
+ for item in items
+ if item.get("folder") is not None
+ and item.get("id")
+ and item.get("name")
+ ]
+ except Exception:
+ logger.warning(
+ "Error fetching folders for connector %s",
+ cid,
+ exc_info=True,
+ )
+ parent_folders[cid] = []
+
+ context: dict[str, Any] = {
+ "accounts": accounts,
+ "parent_folders": parent_folders,
}
- docx_bytes = _markdown_to_docx(final_content or "")
-
- client = OneDriveClient(session=db_session, connector_id=connector.id)
- created = await client.create_file(
- name=final_name,
- parent_id=final_parent_folder_id,
- content=docx_bytes,
- mime_type=DOCX_MIME,
- )
-
- logger.info(
- f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
- )
-
- kb_message_suffix = ""
- try:
- from app.services.onedrive import OneDriveKBSyncService
-
- kb_service = OneDriveKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- file_id=created.get("id"),
- file_name=created.get("name", final_name),
- mime_type=DOCX_MIME,
- web_url=created.get("webUrl"),
- content=final_content,
- connector_id=connector.id,
- search_space_id=search_space_id,
- user_id=user_id,
+ result = request_approval(
+ action_type="onedrive_file_creation",
+ tool_name="create_onedrive_file",
+ params={
+ "name": name,
+ "content": content,
+ "connector_id": None,
+ "parent_folder_id": None,
+ },
+ context=context,
)
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
- else:
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
- return {
- "status": "success",
- "file_id": created.get("id"),
- "name": created.get("name"),
- "web_url": created.get("webUrl"),
- "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_name = result.params.get("name", name)
+ final_content = result.params.get("content", content)
+ final_connector_id = result.params.get("connector_id")
+ final_parent_folder_id = result.params.get("parent_folder_id")
+
+ if not final_name or not final_name.strip():
+ return {"status": "error", "message": "File name cannot be empty."}
+
+ final_name = _ensure_docx_extension(final_name)
+
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ else:
+ connector = connectors[0]
+
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected OneDrive connector is invalid.",
+ }
+
+ docx_bytes = _markdown_to_docx(final_content or "")
+
+ client = OneDriveClient(session=db_session, connector_id=connector.id)
+ created = await client.create_file(
+ name=final_name,
+ parent_id=final_parent_folder_id,
+ content=docx_bytes,
+ mime_type=DOCX_MIME,
+ )
+
+ logger.info(
+ f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
+ )
+
+ kb_message_suffix = ""
+ try:
+ from app.services.onedrive import OneDriveKBSyncService
+
+ kb_service = OneDriveKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ file_id=created.get("id"),
+ file_name=created.get("name", final_name),
+ mime_type=DOCX_MIME,
+ web_url=created.get("webUrl"),
+ content=final_content,
+ connector_id=connector.id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "file_id": created.get("id"),
+ "name": created.get("name"),
+ "web_url": created.get("webUrl"),
+ "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py
index a7f13b5df..4857ea988 100644
--- a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py
@@ -13,6 +13,7 @@ from app.db import (
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
+ async_session_maker,
)
logger = logging.getLogger(__name__)
@@ -23,6 +24,23 @@ def create_delete_onedrive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the delete_onedrive_file tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured delete_onedrive_file tool
+ """
+ del db_session # per-call session — see docstring
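+    # At a glance, the idiom the tool body follows:
+    #     async with async_session_maker() as db_session:   # per call
+    #         ...
+    # versus the old closure over the factory's request-scoped session,
+    # which went stale once the creating request finished.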
+
@tool
async def delete_onedrive_file(
file_name: str,
@@ -56,33 +74,14 @@ def create_delete_onedrive_file_tool(
f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "OneDrive tool not properly configured.",
}
try:
- doc_result = await db_session.execute(
- select(Document)
- .join(
- SearchSourceConnector,
- Document.connector_id == SearchSourceConnector.id,
- )
- .filter(
- and_(
- Document.search_space_id == search_space_id,
- Document.document_type == DocumentType.ONEDRIVE_FILE,
- func.lower(Document.title) == func.lower(file_name),
- SearchSourceConnector.user_id == user_id,
- )
- )
- .order_by(Document.updated_at.desc().nullslast())
- .limit(1)
- )
- document = doc_result.scalars().first()
-
- if not document:
+ async with async_session_maker() as db_session:
doc_result = await db_session.execute(
select(Document)
.join(
@@ -93,13 +92,7 @@ def create_delete_onedrive_file_tool(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.ONEDRIVE_FILE,
- func.lower(
- cast(
- Document.document_metadata["onedrive_file_name"],
- String,
- )
- )
- == func.lower(file_name),
+ func.lower(Document.title) == func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
@@ -108,98 +101,64 @@ def create_delete_onedrive_file_tool(
)
document = doc_result.scalars().first()
- if not document:
- return {
- "status": "not_found",
- "message": (
- f"File '{file_name}' not found in your indexed OneDrive files. "
- "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
- "or (3) the file name is different."
- ),
- }
-
- if not document.connector_id:
- return {
- "status": "error",
- "message": "Document has no associated connector.",
- }
-
- meta = document.document_metadata or {}
- file_id = meta.get("onedrive_file_id")
- document_id = document.id
-
- if not file_id:
- return {
- "status": "error",
- "message": "File ID is missing. Please re-index the file.",
- }
-
- conn_result = await db_session.execute(
- select(SearchSourceConnector).filter(
- and_(
- SearchSourceConnector.id == document.connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
+ if not document:
+ doc_result = await db_session.execute(
+ select(Document)
+ .join(
+ SearchSourceConnector,
+ Document.connector_id == SearchSourceConnector.id,
+ )
+ .filter(
+ and_(
+ Document.search_space_id == search_space_id,
+ Document.document_type == DocumentType.ONEDRIVE_FILE,
+ func.lower(
+ cast(
+ Document.document_metadata[
+ "onedrive_file_name"
+ ],
+ String,
+ )
+ )
+ == func.lower(file_name),
+ SearchSourceConnector.user_id == user_id,
+ )
+ )
+ .order_by(Document.updated_at.desc().nullslast())
+ .limit(1)
)
- )
- )
- connector = conn_result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "OneDrive connector not found or access denied.",
- }
+ document = doc_result.scalars().first()
- cfg = connector.config or {}
- if cfg.get("auth_expired"):
- return {
- "status": "auth_error",
- "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "onedrive",
- }
+ if not document:
+ return {
+ "status": "not_found",
+ "message": (
+ f"File '{file_name}' not found in your indexed OneDrive files. "
+ "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
+ "or (3) the file name is different."
+ ),
+ }
- context = {
- "file": {
- "file_id": file_id,
- "name": file_name,
- "document_id": document_id,
- "web_url": meta.get("web_url"),
- },
- "account": {
- "id": connector.id,
- "name": connector.name,
- "user_email": cfg.get("user_email"),
- },
- }
+ if not document.connector_id:
+ return {
+ "status": "error",
+ "message": "Document has no associated connector.",
+ }
- result = request_approval(
- action_type="onedrive_file_trash",
- tool_name="delete_onedrive_file",
- params={
- "file_id": file_id,
- "connector_id": connector.id,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
+ meta = document.document_metadata or {}
+ file_id = meta.get("onedrive_file_id")
+ document_id = document.id
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
+ if not file_id:
+ return {
+ "status": "error",
+ "message": "File ID is missing. Please re-index the file.",
+ }
- final_file_id = result.params.get("file_id", file_id)
- final_connector_id = result.params.get("connector_id", connector.id)
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- if final_connector_id != connector.id:
- result = await db_session.execute(
+ conn_result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
- SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
@@ -207,65 +166,130 @@ def create_delete_onedrive_file_tool(
)
)
)
- validated_connector = result.scalars().first()
- if not validated_connector:
+ connector = conn_result.scalars().first()
+ if not connector:
return {
"status": "error",
- "message": "Selected OneDrive connector is invalid or has been disconnected.",
+ "message": "OneDrive connector not found or access denied.",
}
- actual_connector_id = validated_connector.id
- else:
- actual_connector_id = connector.id
- logger.info(
- f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
- )
+ cfg = connector.config or {}
+ if cfg.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "onedrive",
+ }
- client = OneDriveClient(
- session=db_session, connector_id=actual_connector_id
- )
- await client.trash_file(final_file_id)
+ context = {
+ "file": {
+ "file_id": file_id,
+ "name": file_name,
+ "document_id": document_id,
+ "web_url": meta.get("web_url"),
+ },
+ "account": {
+ "id": connector.id,
+ "name": connector.name,
+ "user_email": cfg.get("user_email"),
+ },
+ }
- logger.info(
- f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
- )
-
- trash_result: dict[str, Any] = {
- "status": "success",
- "file_id": final_file_id,
- "message": f"Successfully moved '{file_name}' to the recycle bin.",
- }
-
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- doc = doc_result.scalars().first()
- if doc:
- await db_session.delete(doc)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- trash_result["warning"] = (
- f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
- )
-
- trash_result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- trash_result["message"] = (
- f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ result = request_approval(
+ action_type="onedrive_file_trash",
+ tool_name="delete_onedrive_file",
+ params={
+ "file_id": file_id,
+ "connector_id": connector.id,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- return trash_result
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_file_id = result.params.get("file_id", file_id)
+ final_connector_id = result.params.get("connector_id", connector.id)
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
+
+ if final_connector_id != connector.id:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ and_(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id
+ == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
+ )
+ )
+ )
+ validated_connector = result.scalars().first()
+ if not validated_connector:
+ return {
+ "status": "error",
+ "message": "Selected OneDrive connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = validated_connector.id
+ else:
+ actual_connector_id = connector.id
+
+ logger.info(
+ f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
+ )
+
+ client = OneDriveClient(
+ session=db_session, connector_id=actual_connector_id
+ )
+ await client.trash_file(final_file_id)
+
+ logger.info(
+ f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
+ )
+
+ trash_result: dict[str, Any] = {
+ "status": "success",
+ "file_id": final_file_id,
+ "message": f"Successfully moved '{file_name}' to the recycle bin.",
+ }
+
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ doc = doc_result.scalars().first()
+ if doc:
+ await db_session.delete(doc)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ trash_result["warning"] = (
+ f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
+ )
+
+ trash_result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ trash_result["message"] = (
+ f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py
index e8bab36fd..b842d7a20 100644
--- a/surfsense_backend/app/agents/new_chat/tools/registry.py
+++ b/surfsense_backend/app/agents/new_chat/tools/registry.py
@@ -824,13 +824,22 @@ async def build_tools_async(
"""Async version of build_tools that also loads MCP tools from database.
Design Note:
- This function exists because MCP tools require database queries to load user configs,
- while built-in tools are created synchronously from static code.
+ This function exists because MCP tools require database queries to load
+ user configs, while built-in tools are created synchronously from static
+ code.
- Alternative: We could make build_tools() itself async and always query the database,
- but that would force async everywhere even when only using built-in tools. The current
- design keeps the simple case (static tools only) synchronous while supporting dynamic
- database-loaded tools through this async wrapper.
+ Alternative: We could make build_tools() itself async and always query
+ the database, but that would force async everywhere even when only using
+ built-in tools. The current design keeps the simple case (static tools
+ only) synchronous while supporting dynamic database-loaded tools through
+ this async wrapper.
+
+ Phase 1.3: built-in tool construction (CPU; runs in a thread pool to
+ avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on
+    the event loop) are kicked off concurrently. Cold-path wall time is
+    bounded by the slower of the two (typically MCP at ~200ms-1.7s), so
+    the parallelization recovers the ~50-200ms previously spent serially
+    on built-in construction.
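+
+    Worked example (representative numbers, not a measurement): the old
+    serial path cost built-in (~150ms) + MCP (~600ms) = ~750ms; run in
+    parallel, the wall time is max(150ms, 600ms) = ~600ms, recovering
+    the built-in share.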
Args:
dependencies: Dict containing all possible dependencies
@@ -843,33 +852,70 @@ async def build_tools_async(
List of configured tool instances ready for the agent, including MCP tools.
"""
+ import asyncio
import time
_perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG)
+ can_load_mcp = (
+ include_mcp_tools
+ and "db_session" in dependencies
+ and "search_space_id" in dependencies
+ )
+
+ # Built-in tool construction is synchronous + CPU-only. Off-loop it so
+    # MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is a pure
+    # function of its inputs, so it is safe to run in a worker thread.
_t0 = time.perf_counter()
- tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
+ builtin_task = asyncio.create_task(
+ asyncio.to_thread(
+ build_tools, dependencies, enabled_tools, disabled_tools, additional_tools
+ )
+ )
+
+ mcp_task: asyncio.Task | None = None
+ if can_load_mcp:
+ mcp_task = asyncio.create_task(
+ load_mcp_tools(
+ dependencies["db_session"],
+ dependencies["search_space_id"],
+ )
+ )
+
+ # Surface failures from each task independently so a flaky MCP
+ # endpoint never poisons built-in tool registration. ``return_exceptions``
+ # gives us per-task exceptions instead of dropping the second result
+ # when the first raises.
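+    # E.g. gather(a, b, return_exceptions=True) evaluates to something
+    # like [tools_list, RuntimeError(...)] rather than raising and
+    # discarding the still-useful first result.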
+ if mcp_task is not None:
+ builtin_result, mcp_result = await asyncio.gather(
+ builtin_task, mcp_task, return_exceptions=True
+ )
+ else:
+ builtin_result = await builtin_task
+ mcp_result = None
+
+ if isinstance(builtin_result, BaseException):
+ raise builtin_result # built-in registration failure is non-recoverable
+ tools: list[BaseTool] = builtin_result
_perf_log.info(
- "[build_tools_async] Built-in tools in %.3fs (%d tools)",
+ "[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)",
time.perf_counter() - _t0,
len(tools),
)
- # Load MCP tools if requested and dependencies are available
- if (
- include_mcp_tools
- and "db_session" in dependencies
- and "search_space_id" in dependencies
- ):
- try:
- _t0 = time.perf_counter()
- mcp_tools = await load_mcp_tools(
- dependencies["db_session"],
- dependencies["search_space_id"],
+ if mcp_task is not None:
+ if isinstance(mcp_result, BaseException):
+ # ``return_exceptions=True`` captures the exception out-of-band,
+ # so ``sys.exc_info()`` is empty here. Pass the captured
+ # exception via ``exc_info=`` to get a real traceback.
+ logging.error(
+ "Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result
)
+ else:
+ mcp_tools = mcp_result or []
_perf_log.info(
- "[build_tools_async] MCP tools loaded in %.3fs (%d tools)",
+ "[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)",
time.perf_counter() - _t0,
len(mcp_tools),
)
@@ -879,8 +925,6 @@ async def build_tools_async(
len(mcp_tools),
[t.name for t in mcp_tools],
)
- except Exception as e:
- logging.exception("Failed to load MCP tools: %s", e)
logging.info(
"Total tools for agent: %d — %s",
diff --git a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py b/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py
index b8b1527c7..2965f2f02 100644
--- a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py
+++ b/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py
@@ -15,7 +15,7 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
-from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument
+from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker
from app.utils.document_converters import embed_text
@@ -124,12 +124,19 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
"""
Factory function to create the search_surfsense_docs tool.
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
Args:
- db_session: Database session for executing queries
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
Returns:
A configured tool function for searching Surfsense documentation
"""
+ del db_session # per-call session — see docstring
@tool
async def search_surfsense_docs(query: str, top_k: int = 10) -> str:
@@ -155,10 +162,11 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
Returns:
Relevant documentation content formatted with chunk IDs for citations
"""
- return await search_surfsense_docs_async(
- query=query,
- db_session=db_session,
- top_k=top_k,
- )
+ async with async_session_maker() as db_session:
+ return await search_surfsense_docs_async(
+ query=query,
+ db_session=db_session,
+ top_k=top_k,
+ )
return search_surfsense_docs
diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py
index d7b000853..0fc52b5c7 100644
--- a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py
+++ b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_list_teams_channels_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the list_teams_channels tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured list_teams_channels tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def list_teams_channels() -> dict[str, Any]:
"""List all Microsoft Teams and their channels the user has access to.
@@ -23,63 +42,66 @@ def create_list_teams_channels_tool(
Dictionary with status and a list of teams, each containing
team_id, team_name, and a list of channels (id, name).
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
try:
- connector = await get_teams_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Teams connector found."}
-
- token = await get_access_token(db_session, connector)
- headers = {"Authorization": f"Bearer {token}"}
-
- async with httpx.AsyncClient(timeout=20.0) as client:
- teams_resp = await client.get(
- f"{GRAPH_API}/me/joinedTeams", headers=headers
+ async with async_session_maker() as db_session:
+ connector = await get_teams_connector(
+ db_session, search_space_id, user_id
)
+ if not connector:
+ return {"status": "error", "message": "No Teams connector found."}
- if teams_resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Teams token expired. Please re-authenticate.",
- "connector_type": "teams",
- }
- if teams_resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Graph API error: {teams_resp.status_code}",
- }
+ token = await get_access_token(db_session, connector)
+ headers = {"Authorization": f"Bearer {token}"}
- teams_data = teams_resp.json().get("value", [])
- result_teams = []
-
- async with httpx.AsyncClient(timeout=20.0) as client:
- for team in teams_data:
- team_id = team["id"]
- ch_resp = await client.get(
- f"{GRAPH_API}/teams/{team_id}/channels",
- headers=headers,
- )
- channels = []
- if ch_resp.status_code == 200:
- channels = [
- {"id": ch["id"], "name": ch.get("displayName", "")}
- for ch in ch_resp.json().get("value", [])
- ]
- result_teams.append(
- {
- "team_id": team_id,
- "team_name": team.get("displayName", ""),
- "channels": channels,
- }
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ teams_resp = await client.get(
+ f"{GRAPH_API}/me/joinedTeams", headers=headers
)
- return {
- "status": "success",
- "teams": result_teams,
- "total_teams": len(result_teams),
- }
+ if teams_resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Teams token expired. Please re-authenticate.",
+ "connector_type": "teams",
+ }
+ if teams_resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Graph API error: {teams_resp.status_code}",
+ }
+
+ teams_data = teams_resp.json().get("value", [])
+ result_teams = []
+
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ for team in teams_data:
+ team_id = team["id"]
+ ch_resp = await client.get(
+ f"{GRAPH_API}/teams/{team_id}/channels",
+ headers=headers,
+ )
+ channels = []
+ if ch_resp.status_code == 200:
+ channels = [
+ {"id": ch["id"], "name": ch.get("displayName", "")}
+ for ch in ch_resp.json().get("value", [])
+ ]
+ result_teams.append(
+ {
+ "team_id": team_id,
+ "team_name": team.get("displayName", ""),
+ "channels": channels,
+ }
+ )
+
+ return {
+ "status": "success",
+ "teams": result_teams,
+ "total_teams": len(result_teams),
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py
index d24a7e4d3..0ebda021e 100644
--- a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py
+++ b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_read_teams_messages_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the read_teams_messages tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured read_teams_messages tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def read_teams_messages(
team_id: str,
@@ -32,65 +51,68 @@ def create_read_teams_messages_tool(
Dictionary with status and a list of messages including
id, sender, content, timestamp.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
limit = min(limit, 50)
try:
- connector = await get_teams_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Teams connector found."}
-
- token = await get_access_token(db_session, connector)
-
- async with httpx.AsyncClient(timeout=20.0) as client:
- resp = await client.get(
- f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
- headers={"Authorization": f"Bearer {token}"},
- params={"$top": limit},
+ async with async_session_maker() as db_session:
+ connector = await get_teams_connector(
+ db_session, search_space_id, user_id
)
+ if not connector:
+ return {"status": "error", "message": "No Teams connector found."}
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Teams token expired. Please re-authenticate.",
- "connector_type": "teams",
- }
- if resp.status_code == 403:
- return {
- "status": "error",
- "message": "Insufficient permissions to read this channel.",
- }
- if resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Graph API error: {resp.status_code}",
- }
+ token = await get_access_token(db_session, connector)
- raw_msgs = resp.json().get("value", [])
- messages = []
- for m in raw_msgs:
- sender = m.get("from", {})
- user_info = sender.get("user", {}) if sender else {}
- body = m.get("body", {})
- messages.append(
- {
- "id": m.get("id"),
- "sender": user_info.get("displayName", "Unknown"),
- "content": body.get("content", ""),
- "content_type": body.get("contentType", "text"),
- "timestamp": m.get("createdDateTime", ""),
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ resp = await client.get(
+ f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
+ headers={"Authorization": f"Bearer {token}"},
+ params={"$top": limit},
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Teams token expired. Please re-authenticate.",
+ "connector_type": "teams",
+ }
+ if resp.status_code == 403:
+ return {
+ "status": "error",
+ "message": "Insufficient permissions to read this channel.",
+ }
+ if resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Graph API error: {resp.status_code}",
}
- )
- return {
- "status": "success",
- "team_id": team_id,
- "channel_id": channel_id,
- "messages": messages,
- "total": len(messages),
- }
+ raw_msgs = resp.json().get("value", [])
+ messages = []
+ for m in raw_msgs:
+ sender = m.get("from", {})
+ user_info = sender.get("user", {}) if sender else {}
+ body = m.get("body", {})
+ messages.append(
+ {
+ "id": m.get("id"),
+ "sender": user_info.get("displayName", "Unknown"),
+ "content": body.get("content", ""),
+ "content_type": body.get("contentType", "text"),
+ "timestamp": m.get("createdDateTime", ""),
+ }
+ )
+
+ return {
+ "status": "success",
+ "team_id": team_id,
+ "channel_id": channel_id,
+ "messages": messages,
+ "total": len(messages),
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py
index fd8d00870..6f40d27e1 100644
--- a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py
+++ b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py
@@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector
@@ -17,6 +18,23 @@ def create_send_teams_message_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the send_teams_message tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured send_teams_message tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def send_teams_message(
team_id: str,
@@ -39,70 +57,73 @@ def create_send_teams_message_tool(
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
try:
- connector = await get_teams_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Teams connector found."}
+ async with async_session_maker() as db_session:
+ connector = await get_teams_connector(
+ db_session, search_space_id, user_id
+ )
+ if not connector:
+ return {"status": "error", "message": "No Teams connector found."}
- result = request_approval(
- action_type="teams_send_message",
- tool_name="send_teams_message",
- params={
- "team_id": team_id,
- "channel_id": channel_id,
- "content": content,
- },
- context={"connector_id": connector.id},
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Message was not sent.",
- }
-
- final_content = result.params.get("content", content)
- final_team = result.params.get("team_id", team_id)
- final_channel = result.params.get("channel_id", channel_id)
-
- token = await get_access_token(db_session, connector)
-
- async with httpx.AsyncClient(timeout=20.0) as client:
- resp = await client.post(
- f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
- headers={
- "Authorization": f"Bearer {token}",
- "Content-Type": "application/json",
+ result = request_approval(
+ action_type="teams_send_message",
+ tool_name="send_teams_message",
+ params={
+ "team_id": team_id,
+ "channel_id": channel_id,
+ "content": content,
},
- json={"body": {"content": final_content}},
+ context={"connector_id": connector.id},
)
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Teams token expired. Please re-authenticate.",
- "connector_type": "teams",
- }
- if resp.status_code == 403:
- return {
- "status": "insufficient_permissions",
- "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
- }
- if resp.status_code not in (200, 201):
- return {
- "status": "error",
- "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Message was not sent.",
+ }
- msg_data = resp.json()
- return {
- "status": "success",
- "message_id": msg_data.get("id"),
- "message": "Message sent to Teams channel.",
- }
+ final_content = result.params.get("content", content)
+ final_team = result.params.get("team_id", team_id)
+ final_channel = result.params.get("channel_id", channel_id)
+
+ token = await get_access_token(db_session, connector)
+
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ resp = await client.post(
+ f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
+ headers={
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json",
+ },
+ json={"body": {"content": final_content}},
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Teams token expired. Please re-authenticate.",
+ "connector_type": "teams",
+ }
+ if resp.status_code == 403:
+ return {
+ "status": "insufficient_permissions",
+ "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
+ }
+ if resp.status_code not in (200, 201):
+ return {
+ "status": "error",
+ "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}",
+ }
+
+ msg_data = resp.json()
+ return {
+ "status": "success",
+ "message_id": msg_data.get("id"),
+ "message": "Message sent to Teams channel.",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/update_memory.py b/surfsense_backend/app/agents/new_chat/tools/update_memory.py
index 4128ac0dc..fbc9edbba 100644
--- a/surfsense_backend/app/agents/new_chat/tools/update_memory.py
+++ b/surfsense_backend/app/agents/new_chat/tools/update_memory.py
@@ -26,7 +26,7 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
-from app.db import SearchSpace, User
+from app.db import SearchSpace, User, async_session_maker
logger = logging.getLogger(__name__)
@@ -295,6 +295,25 @@ def create_update_memory_tool(
db_session: AsyncSession,
llm: Any | None = None,
):
+ """Factory function to create the user-memory update tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+ The session's bound ``commit``/``rollback`` methods are captured at
+ call time, after ``async with`` has bound ``db_session`` locally.
+
+ Args:
+ user_id: ID of the user whose memory document is being updated.
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ llm: Optional LLM for the forced-rewrite path.
+
+ Returns:
+ Configured update_memory tool for the user-memory scope.
+ """
+ del db_session # per-call session — see docstring
uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool
@@ -311,26 +330,26 @@ def create_update_memory_tool(
updated_memory: The FULL updated markdown document (not a diff).
"""
try:
- result = await db_session.execute(select(User).where(User.id == uid))
- user = result.scalars().first()
- if not user:
- return {"status": "error", "message": "User not found."}
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(select(User).where(User.id == uid))
+ user = result.scalars().first()
+ if not user:
+ return {"status": "error", "message": "User not found."}
- old_memory = user.memory_md
+ old_memory = user.memory_md
- return await _save_memory(
- updated_memory=updated_memory,
- old_memory=old_memory,
- llm=llm,
- apply_fn=lambda content: setattr(user, "memory_md", content),
- commit_fn=db_session.commit,
- rollback_fn=db_session.rollback,
- label="memory",
- scope="user",
- )
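+                # commit_fn / rollback_fn bind to the per-call session
+                # opened by ``async with`` above, never to a session
+                # captured at factory time.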
+ return await _save_memory(
+ updated_memory=updated_memory,
+ old_memory=old_memory,
+ llm=llm,
+ apply_fn=lambda content: setattr(user, "memory_md", content),
+ commit_fn=db_session.commit,
+ rollback_fn=db_session.rollback,
+ label="memory",
+ scope="user",
+ )
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
- await db_session.rollback()
return {
"status": "error",
"message": f"Failed to update memory: {e}",
@@ -344,6 +363,27 @@ def create_update_team_memory_tool(
db_session: AsyncSession,
llm: Any | None = None,
):
+ """Factory function to create the team-memory update tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+ The session's bound ``commit``/``rollback`` methods are captured at
+ call time, after ``async with`` has bound ``db_session`` locally.
+
+ Args:
+ search_space_id: ID of the search space whose team memory is being
+ updated.
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ llm: Optional LLM for the forced-rewrite path.
+
+ Returns:
+ Configured update_memory tool for the team-memory scope.
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space.
@@ -359,28 +399,30 @@ def create_update_team_memory_tool(
updated_memory: The FULL updated markdown document (not a diff).
"""
try:
- result = await db_session.execute(
- select(SearchSpace).where(SearchSpace.id == search_space_id)
- )
- space = result.scalars().first()
- if not space:
- return {"status": "error", "message": "Search space not found."}
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(
+ select(SearchSpace).where(SearchSpace.id == search_space_id)
+ )
+ space = result.scalars().first()
+ if not space:
+ return {"status": "error", "message": "Search space not found."}
- old_memory = space.shared_memory_md
+ old_memory = space.shared_memory_md
- return await _save_memory(
- updated_memory=updated_memory,
- old_memory=old_memory,
- llm=llm,
- apply_fn=lambda content: setattr(space, "shared_memory_md", content),
- commit_fn=db_session.commit,
- rollback_fn=db_session.rollback,
- label="team memory",
- scope="team",
- )
+ return await _save_memory(
+ updated_memory=updated_memory,
+ old_memory=old_memory,
+ llm=llm,
+ apply_fn=lambda content: setattr(
+ space, "shared_memory_md", content
+ ),
+ commit_fn=db_session.commit,
+ rollback_fn=db_session.rollback,
+ label="team memory",
+ scope="team",
+ )
except Exception as e:
logger.exception("Failed to update team memory: %s", e)
- await db_session.rollback()
return {
"status": "error",
"message": f"Failed to update team memory: {e}",
diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py
index 016c2de42..08194e7fb 100644
--- a/surfsense_backend/app/app.py
+++ b/surfsense_backend/app/app.py
@@ -31,6 +31,7 @@ from app.config import (
initialize_image_gen_router,
initialize_llm_router,
initialize_openrouter_integration,
+ initialize_pricing_registration,
initialize_vision_llm_router,
)
from app.db import User, create_db_and_tables, get_async_session
@@ -420,6 +421,135 @@ def _stop_openrouter_background_refresh() -> None:
OpenRouterIntegrationService.get_instance().stop_background_refresh()
+async def _warm_agent_jit_caches() -> None:
+ """Pay the LangChain / LangGraph / Deepagents JIT cost at startup.
+
+ Why
+ ----
+ A cold ``create_agent`` + ``StateGraph.compile()`` + Pydantic schema
+ generation chain takes 1.5-2 seconds of pure CPU on first invocation
+ inside any Python process: the graph compiler builds reducers,
+ Pydantic v2 generates and JITs validator schemas, deepagents
+ eagerly compiles its general-purpose subagent, etc. Subsequent
+ compiles in the same process pay only ~50% of that cost (the lazy
+ JIT bits are cached in module-level dicts).
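+    (By those numbers: with a ~2s cold compile and a ~50% warm residual,
+    paying once here cuts the first real request's compile cost from
+    ~2s to roughly ~1s.)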
+
+ Doing one throwaway compile during ``lifespan`` startup pre-pays
+ that cost so the *first real request* doesn't. We do NOT prime
+ :mod:`agent_cache` because the cache key requires real
+ ``thread_id`` / ``user_id`` / ``search_space_id`` / etc. — the
+ throwaway agent is genuinely thrown away and immediately collected.
+
+ Safety
+ ------
+ * No DB access. We construct a stub LLM (no real keys), pass an
+ empty tools list, and pass ``checkpointer=None`` so we never
+ touch Postgres.
+    * The caller wraps this in ``asyncio.wait_for`` so a hang here can
+      never block worker startup. On any failure we log and swallow;
+      the worst case is the first real request paying the full cold
+      cost (i.e. pre-warmup behaviour).
+ """
+ import time as _time
+
+ logger = logging.getLogger(__name__)
+ t0 = _time.perf_counter()
+ try:
+ from langchain.agents import create_agent
+ from langchain.agents.middleware import (
+ ModelCallLimitMiddleware,
+ TodoListMiddleware,
+ ToolCallLimitMiddleware,
+ )
+ from langchain_core.language_models.fake_chat_models import (
+ FakeListChatModel,
+ )
+ from langchain_core.tools import tool
+
+ from app.agents.new_chat.context import SurfSenseContextSchema
+
+ # Minimal LLM stub. ``FakeListChatModel`` satisfies
+ # ``BaseChatModel`` without any network or auth — perfect for
+ # exercising the compile path without side effects.
+ stub_llm = FakeListChatModel(responses=["warmup-response"])
+
+ # Two trivial tools with arg + return schemas — exercises the
+ # Pydantic v2 schema JIT path. Without at least one tool the
+ # graph compile skips the tool-loop bytecode generation that
+ # accounts for ~30-50% of cold compile cost.
+ @tool
+ def _warmup_tool_a(query: str, limit: int = 5) -> str:
+ """Warmup tool A — never actually invoked."""
+ return query[:limit]
+
+ @tool
+ def _warmup_tool_b(name: str, value: float | None = None) -> dict[str, object]:
+ """Warmup tool B — never actually invoked."""
+ return {"name": name, "value": value}
+
+ # A handful of common middleware so the compile pre-pays the
+ # ``AgentMiddleware`` resolver path. These instances never run
+ # because the throwaway agent is immediately collected.
+ # ``SubAgentMiddleware`` is the single heaviest line in cold
+ # ``create_surfsense_deep_agent`` (1.5-2s of CPU per call to
+ # compile its general-purpose subagent's full inner graph),
+ # so we include it here to make sure that compile path is JIT'd.
+ warmup_middleware: list = [
+ TodoListMiddleware(),
+ ModelCallLimitMiddleware(
+ thread_limit=120, run_limit=80, exit_behavior="end"
+ ),
+ ToolCallLimitMiddleware(
+ thread_limit=300, run_limit=80, exit_behavior="continue"
+ ),
+ ]
+ try:
+ from deepagents import SubAgentMiddleware
+ from deepagents.backends import StateBackend
+ from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
+
+ gp_warmup_spec = { # type: ignore[var-annotated]
+ **GENERAL_PURPOSE_SUBAGENT,
+ "model": stub_llm,
+ "tools": [_warmup_tool_a],
+ "middleware": [TodoListMiddleware()],
+ }
+ warmup_middleware.append(
+ SubAgentMiddleware(backend=StateBackend, subagents=[gp_warmup_spec])
+ )
+ except Exception:
+ # Deepagents missing/incompatible — middleware-only warmup
+ # still produces a useful (smaller) speedup.
+ logger.debug("[startup] SubAgentMiddleware warmup skipped", exc_info=True)
+
+ compiled = create_agent(
+ stub_llm,
+ tools=[_warmup_tool_a, _warmup_tool_b],
+ system_prompt="You are a warmup stub.",
+ middleware=warmup_middleware,
+ context_schema=SurfSenseContextSchema,
+ checkpointer=None,
+ )
+
+ # Touch the compiled graph's stream_channels / nodes so any
+ # remaining lazy schema work fires now instead of on first
+ # real invocation.
+ _ = list(getattr(compiled, "nodes", {}).keys())
+
+ del compiled
+ logger.info(
+ "[startup] Agent JIT warmup completed in %.3fs",
+ _time.perf_counter() - t0,
+ )
+ except Exception:
+ logger.warning(
+ "[startup] Agent JIT warmup failed in %.3fs (non-fatal — first "
+ "real request will pay the full compile cost)",
+ _time.perf_counter() - t0,
+ exc_info=True,
+ )
+
+
@asynccontextmanager
async def lifespan(app: FastAPI):
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
@@ -432,6 +562,7 @@ async def lifespan(app: FastAPI):
await setup_checkpointer_tables()
initialize_openrouter_integration()
_start_openrouter_background_refresh()
+ initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()
@@ -443,6 +574,18 @@ async def lifespan(app: FastAPI):
"Docs will be indexed on the next restart."
)
+ # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
+    # worker readiness; ``shield`` is used so Uvicorn cancelling startup
+ # doesn't leave half-warmed Pydantic schemas in an inconsistent
+ # state.
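+    # Note: if the timeout fires, ``shield`` lets the warmup task keep
+    # running to completion in the background; only this startup wait is
+    # abandoned.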
+ try:
+ await asyncio.wait_for(asyncio.shield(_warm_agent_jit_caches()), timeout=20)
+    except Exception:  # pragma: no cover - defensive; TimeoutError included
+ logging.getLogger(__name__).warning(
+ "[startup] Agent JIT warmup hit timeout/error — skipping; "
+ "first real request will pay the full compile cost."
+ )
+
log_system_snapshot("startup_complete")
yield
@@ -452,6 +595,23 @@ async def lifespan(app: FastAPI):
def registration_allowed():
+ """Master auth kill switch keyed on the REGISTRATION_ENABLED env var.
+
+ Despite the name, this dependency does NOT only gate registration. When
+ REGISTRATION_ENABLED is FALSE it intentionally blocks every auth surface
+ that could mint or refresh a session for an attacker:
+
+ * email/password ``POST /auth/register``
+ * email/password ``POST /auth/jwt/login``
+    * the Google OAuth router (``/auth/google/authorize`` plus the shared
+      ``/auth/google/callback``, which handles both new signups and logins
+      for existing users, so flipping this off locks out both)
+ * the bespoke ``/auth/google/authorize-redirect`` helper used by the UI
+
+ Use it as a temporary "freeze all new sessions" lever during incident
+ response. It is not a way to disable signup while keeping login working;
+ for that, override ``UserManager.oauth_callback`` instead.
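+
+    Example: set ``REGISTRATION_ENABLED=FALSE`` in the environment and
+    restart; the routes listed above (where mounted) then return 403
+    through this dependency.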
+ """
if not config.REGISTRATION_ENABLED:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled"
@@ -596,32 +756,45 @@ app.add_middleware(
allow_headers=["*"], # Allows all headers
)
-app.include_router(
- fastapi_users.get_auth_router(auth_backend),
- prefix="/auth/jwt",
- tags=["auth"],
- dependencies=[Depends(rate_limit_login)],
-)
-app.include_router(
- fastapi_users.get_register_router(UserRead, UserCreate),
- prefix="/auth",
- tags=["auth"],
- dependencies=[
- Depends(rate_limit_register),
- Depends(registration_allowed), # blocks registration when disabled
- ],
-)
-app.include_router(
- fastapi_users.get_reset_password_router(),
- prefix="/auth",
- tags=["auth"],
- dependencies=[Depends(rate_limit_password_reset)],
-)
-app.include_router(
- fastapi_users.get_verify_router(UserRead),
- prefix="/auth",
- tags=["auth"],
-)
+# Password / email-based auth routers are only mounted when not running in
+# Google-OAuth-only mode. Mounting them in OAuth-only prod previously left
+# POST /auth/register reachable, the bypass that let bots create non-OAuth
+# users despite AUTH_TYPE=GOOGLE.
+if config.AUTH_TYPE != "GOOGLE":
+ app.include_router(
+ fastapi_users.get_auth_router(auth_backend),
+ prefix="/auth/jwt",
+ tags=["auth"],
+ dependencies=[
+ Depends(rate_limit_login),
+ Depends(
+ registration_allowed
+ ), # honour REGISTRATION_ENABLED kill switch on login too
+ ],
+ )
+ app.include_router(
+ fastapi_users.get_register_router(UserRead, UserCreate),
+ prefix="/auth",
+ tags=["auth"],
+ dependencies=[
+ Depends(rate_limit_register),
+ Depends(registration_allowed),
+ ],
+ )
+ app.include_router(
+ fastapi_users.get_reset_password_router(),
+ prefix="/auth",
+ tags=["auth"],
+ dependencies=[Depends(rate_limit_password_reset)],
+ )
+ app.include_router(
+ fastapi_users.get_verify_router(UserRead),
+ prefix="/auth",
+ tags=["auth"],
+ )
+
+# /users/me (read/update profile) is needed in every auth mode, so it stays
+# mounted unconditionally.
app.include_router(
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
@@ -679,16 +852,25 @@ if config.AUTH_TYPE == "GOOGLE":
),
prefix="/auth/google",
tags=["auth"],
- dependencies=[
- Depends(registration_allowed)
- ], # blocks OAuth registration when disabled
+ # REGISTRATION_ENABLED is a master auth kill switch: when set to FALSE
+ # it blocks BOTH new OAuth signups AND login of existing OAuth users
+ # (the fastapi-users OAuth router shares one callback for create+login,
+ # so this dependency closes both paths together).
+ dependencies=[Depends(registration_allowed)],
)
# Add a redirect-based authorize endpoint for Firefox/Safari compatibility
# This endpoint performs a server-side redirect instead of returning JSON
# which fixes cross-site cookie issues where browsers don't send cookies
- # set via cross-origin fetch requests on subsequent redirects
- @app.get("/auth/google/authorize-redirect", tags=["auth"])
+ # set via cross-origin fetch requests on subsequent redirects.
+ # The registration_allowed dependency mirrors the OAuth router above so
+ # the kill switch fails fast here instead of bouncing users to Google
+ # only to 403 on the callback.
+ @app.get(
+ "/auth/google/authorize-redirect",
+ tags=["auth"],
+ dependencies=[Depends(registration_allowed)],
+ )
async def google_authorize_redirect(
request: Request,
):
diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py
index 58a8b0f39..74710d5e1 100644
--- a/surfsense_backend/app/celery_app.py
+++ b/surfsense_backend/app/celery_app.py
@@ -22,10 +22,12 @@ def init_worker(**kwargs):
initialize_image_gen_router,
initialize_llm_router,
initialize_openrouter_integration,
+ initialize_pricing_registration,
initialize_vision_llm_router,
)
initialize_openrouter_integration()
+ initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()
diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py
index 86482767d..f6f0c7f62 100644
--- a/surfsense_backend/app/config/__init__.py
+++ b/surfsense_backend/app/config/__init__.py
@@ -47,11 +47,37 @@ def load_global_llm_configs():
data = yaml.safe_load(f)
configs = data.get("global_llm_configs", [])
+ # Lazy import keeps the `app.config` -> `app.services` edge one-way
+ # and matches the `provider_api_base` pattern used elsewhere.
+ from app.services.provider_capabilities import derive_supports_image_input
+
seen_slugs: dict[str, int] = {}
for cfg in configs:
cfg.setdefault("billing_tier", "free")
cfg.setdefault("anonymous_enabled", False)
cfg.setdefault("seo_enabled", False)
+ # Capability flag: explicit YAML override always wins. When the
+ # operator has not annotated the model, defer to LiteLLM's
+ # authoritative model map (`supports_vision`) which already
+ # knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are
+ # vision-capable. Unknown / unmapped models default-allow so
+ # we don't lock the user out of a freshly added third-party
+ # entry; the streaming-task safety net (driven by
+ # `is_known_text_only_chat_model`) is the only place a False
+ # actually blocks a request.
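+        # Illustrative YAML entry with an explicit override (hypothetical
+        # model name, not from the shipped config):
+        #   - model_name: local-llava
+        #     provider: CUSTOM
+        #     supports_image_input: true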
+ if "supports_image_input" not in cfg:
+ litellm_params = cfg.get("litellm_params") or {}
+ base_model = (
+ litellm_params.get("base_model")
+ if isinstance(litellm_params, dict)
+ else None
+ )
+ cfg["supports_image_input"] = derive_supports_image_input(
+ provider=cfg.get("provider"),
+ model_name=cfg.get("model_name"),
+ base_model=base_model,
+ custom_provider=cfg.get("custom_provider"),
+ )
if cfg.get("seo_enabled") and cfg.get("seo_slug"):
slug = cfg["seo_slug"]
@@ -63,6 +89,27 @@ def load_global_llm_configs():
else:
seen_slugs[slug] = cfg.get("id", 0)
+ # Stamp Auto (Fastest) ranking metadata. YAML configs are always
+ # Tier A — operator-curated, locked first when premium-eligible.
+ # The OpenRouter refresh tick later re-stamps health for any cfg
+ # whose provider == "OPENROUTER" via _enrich_health.
+ try:
+ from app.services.quality_score import static_score_yaml
+
+ for cfg in configs:
+ cfg["auto_pin_tier"] = "A"
+ static_q = static_score_yaml(cfg)
+ cfg["quality_score_static"] = static_q
+ cfg["quality_score"] = static_q
+ cfg["quality_score_health"] = None
+ # YAML cfgs whose provider is OPENROUTER are also subject
+ # to health gating against their own /endpoints data — a
+ # hand-picked dead OR model is still dead. _enrich_health
+ # re-stamps health_gated for them on the next refresh tick.
+ cfg["health_gated"] = False
+ except Exception as e:
+ print(f"Warning: Failed to score global LLM configs: {e}")
+
return configs
except Exception as e:
print(f"Warning: Failed to load global LLM configs: {e}")
@@ -117,7 +164,11 @@ def load_global_image_gen_configs():
try:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
- return data.get("global_image_generation_configs", [])
+ configs = data.get("global_image_generation_configs", []) or []
+ for cfg in configs:
+ if isinstance(cfg, dict):
+ cfg.setdefault("billing_tier", "free")
+ return configs
except Exception as e:
print(f"Warning: Failed to load global image generation configs: {e}")
return []
@@ -132,7 +183,11 @@ def load_global_vision_llm_configs():
try:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
- return data.get("global_vision_llm_configs", [])
+ configs = data.get("global_vision_llm_configs", []) or []
+ for cfg in configs:
+ if isinstance(cfg, dict):
+ cfg.setdefault("billing_tier", "free")
+ return configs
except Exception as e:
print(f"Warning: Failed to load global vision LLM configs: {e}")
return []
@@ -194,6 +249,9 @@ def load_openrouter_integration_settings() -> dict | None:
"""
Load OpenRouter integration settings from the YAML config.
+ Emits startup warnings for deprecated keys (``billing_tier``,
+ ``anonymous_enabled``) and seeds their replacements for back-compat.
+
Returns:
dict with settings if present and enabled, None otherwise
"""
@@ -206,9 +264,40 @@ def load_openrouter_integration_settings() -> dict | None:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
settings = data.get("openrouter_integration")
- if settings and settings.get("enabled"):
- return settings
- return None
+ if not settings or not settings.get("enabled"):
+ return None
+
+ if "billing_tier" in settings:
+ print(
+ "Warning: openrouter_integration.billing_tier is deprecated; "
+ "tier is now derived per model from OpenRouter data "
+ "(':free' suffix or zero pricing). Remove this key."
+ )
+
+ if "anonymous_enabled" in settings:
+ print(
+ "Warning: openrouter_integration.anonymous_enabled is "
+ "deprecated; use anonymous_enabled_paid and/or "
+ "anonymous_enabled_free instead. Both new flags have been "
+ "seeded from the legacy value for back-compat."
+ )
+ settings.setdefault(
+ "anonymous_enabled_paid", settings["anonymous_enabled"]
+ )
+ settings.setdefault(
+ "anonymous_enabled_free", settings["anonymous_enabled"]
+ )
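+ # e.g. a legacy `anonymous_enabled: true` seeds both
+ # anonymous_enabled_paid=True and anonymous_enabled_free=True,
+ # unless either new key was already set explicitly.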
+
+ # Image generation + vision LLM emissions are opt-in (issue L).
+ # OpenRouter's catalogue contains hundreds of image / vision
+ # capable models; auto-injecting all of them into every
+ # deployment would explode the model selector and surprise
+ # operators upgrading from prior versions. Default to False so
+ # admins must explicitly turn them on.
+ settings.setdefault("image_generation_enabled", False)
+ settings.setdefault("vision_enabled", False)
+
+ return settings
except Exception as e:
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
return None
@@ -217,9 +306,14 @@ def load_openrouter_integration_settings() -> dict | None:
def initialize_openrouter_integration():
"""
If enabled, fetch all OpenRouter models and append them to
- config.GLOBAL_LLM_CONFIGS as dynamic premium entries.
- Should be called BEFORE initialize_llm_router() so the router
- correctly excludes premium models from Auto mode.
+ config.GLOBAL_LLM_CONFIGS as dynamic entries. Each model's ``billing_tier``
+ is derived per-model from OpenRouter's API signals (``:free`` suffix or
+ zero pricing), so free OpenRouter models correctly skip premium quota.
+
+ Should be called BEFORE initialize_llm_router(). Dynamic entries are
+ tagged ``router_pool_eligible=False`` so the LiteLLM Router pool (used
+ by title-gen / sub-agent flows) remains scoped to curated YAML configs,
+ while user-facing Auto-mode thread pinning still considers them.
"""
settings = load_openrouter_integration_settings()
if not settings:
@@ -235,16 +329,70 @@ def initialize_openrouter_integration():
if new_configs:
config.GLOBAL_LLM_CONFIGS.extend(new_configs)
+ free_count = sum(1 for c in new_configs if c.get("billing_tier") == "free")
+ premium_count = sum(
+ 1 for c in new_configs if c.get("billing_tier") == "premium"
+ )
print(
f"Info: OpenRouter integration added {len(new_configs)} models "
- f"(billing_tier={settings.get('billing_tier', 'premium')})"
+ f"(free={free_count}, premium={premium_count})"
)
else:
print("Info: OpenRouter integration enabled but no models fetched")
+
+ # Image generation + vision LLM emissions are opt-in (issue L).
+ # Both reuse the catalogue already cached by ``service.initialize``
+ # so we don't make additional network calls here.
+ if settings.get("image_generation_enabled"):
+ try:
+ image_configs = service.get_image_generation_configs()
+ if image_configs:
+ config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
+ print(
+ f"Info: OpenRouter integration added {len(image_configs)} "
+ f"image-generation models"
+ )
+ except Exception as e:
+ print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
+
+ if settings.get("vision_enabled"):
+ try:
+ vision_configs = service.get_vision_llm_configs()
+ if vision_configs:
+ config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
+ print(
+ f"Info: OpenRouter integration added {len(vision_configs)} "
+ f"vision LLM models"
+ )
+ except Exception as e:
+ print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
except Exception as e:
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
+def initialize_pricing_registration():
+ """
+ Teach LiteLLM the per-token cost of every deployment in
+ ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled
+ from the OpenRouter catalogue + any operator-declared YAML pricing).
+
+ Must run AFTER ``initialize_openrouter_integration()`` so the
+ OpenRouter catalogue is populated and BEFORE the first LLM call so
+ ``response_cost`` is available in ``TokenTrackingCallback``.
+
+ Failures are logged but never raised — startup must not be blocked
+ by a missing pricing entry; the worst case is that the model debits 0.
+ """
+ try:
+ from app.services.pricing_registration import (
+ register_pricing_from_global_configs,
+ )
+
+ register_pricing_from_global_configs()
+ except Exception as e:
+ print(f"Warning: Failed to register LiteLLM pricing: {e}")
+
+
def initialize_llm_router():
"""
Initialize the LLM Router service for Auto mode.
@@ -389,14 +537,54 @@ class Config:
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
)
- # Premium token quota settings
- PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
+ # Premium credit (micro-USD) quota settings.
+ #
+ # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
+ # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
+ # still honoured for one release as fall-back values — the prior
+ # $1-per-1M-tokens Stripe price means every existing value maps 1:1
+ # to micros, so operators upgrading without changing their .env still
+ # get correct behaviour. A startup deprecation warning fires below if
+ # they're set.
+ PREMIUM_CREDIT_MICROS_LIMIT = int(
+ os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
+ or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
+ )
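+ # e.g. an operator still setting PREMIUM_TOKEN_LIMIT=3000000 (3M
+ # tokens at the old $1-per-1M price) now gets 3_000_000 micros = $3.00.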
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
- STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
+ STRIPE_CREDIT_MICROS_PER_UNIT = int(
+ os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
+ or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
+ )
STRIPE_TOKEN_BUYING_ENABLED = (
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
)
+ # Safety ceiling on the per-call premium reservation. ``stream_new_chat``
+ # estimates an upper-bound cost from ``litellm.get_model_info`` x the
+ # config's ``quota_reserve_tokens`` and clamps the result to this value
+ # so a misconfigured "$1000/M" model can't lock the user's whole balance
+ # on one call. Default $1.00 covers realistic worst-cases (Opus + 4K
+ # reserve_tokens ≈ $0.36) with headroom.
+ QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
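+ # Worked example behind the $0.36 figure (assuming the estimate prices
+ # the full reserve at combined input+output cost): an Opus-class
+ # ~$90/M combined rate x 4_000 reserved tokens -> 4_000 * 90e-6 ≈ $0.36.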
+
+ if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
+ "PREMIUM_CREDIT_MICROS_LIMIT"
+ ):
+ print(
+ "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
+ "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
+ "current Stripe price). The old key will be removed in a "
+ "future release."
+ )
+ if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
+ "STRIPE_CREDIT_MICROS_PER_UNIT"
+ ):
+ print(
+ "Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
+ "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
+ "The old key will be removed in a future release."
+ )
+
# Anonymous / no-login mode settings
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
MULTI_AGENT_CHAT_ENABLED = (
@@ -412,6 +600,35 @@ class Config:
# Default quota reserve tokens when not specified per-model
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
+ # Per-image reservation (in micro-USD) used by ``billable_call`` for the
+ # ``POST /image-generations`` endpoint when the global config does not
+ # override it. $0.05 covers realistic worst-cases for current OpenAI /
+ # OpenRouter image-gen pricing. Bypassed entirely for free configs.
+ QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
+ os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
+ )
+
+ # Per-podcast reservation (in micro-USD). One agent LLM call generating
+ # a transcript, typically 5k-20k completion tokens. $0.20 covers a long
+ # premium-model run. Tune via env.
+ QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
+ os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
+ )
+
+ # Per-video-presentation reservation (in micro-USD). Fan-out of N
+ # slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
+ # plus refine retries; can produce many premium completions. $1.00
+ # covers worst-case. Tune via env.
+ #
+ # NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
+ # 1_000_000. The override path in ``billable_call`` bypasses the
+ # per-call clamp in ``estimate_call_reserve_micros``, so this is the
+ # *actual* hold — raising it via env is fine but means a single video
+ # task can lock $1+ of credit.
+ QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
+ os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
+ )
+
# Abuse prevention: concurrent stream cap and CAPTCHA
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
ANON_CAPTCHA_REQUEST_THRESHOLD = int(
diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml
index 9aca0f022..d92640c8d 100644
--- a/surfsense_backend/app/config/global_llm_config.example.yaml
+++ b/surfsense_backend/app/config/global_llm_config.example.yaml
@@ -19,6 +19,24 @@
# Structure matches NewLLMConfig:
# - Model configuration (provider, model_name, api_key, etc.)
# - Prompt configuration (system_instructions, citations_enabled)
+#
+# COST-BASED PREMIUM CREDITS:
+# Each premium config bills the user's USD-credit balance based on the
+# actual provider cost reported by LiteLLM. For models LiteLLM already
+# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
+# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
+# or any model LiteLLM doesn't have in its built-in pricing table, declare
+# per-token costs inline so they bill correctly:
+#
+# litellm_params:
+# base_model: "my-custom-azure-deploy"
+# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
+# input_cost_per_token: 0.000003
+# output_cost_per_token: 0.000015
+#
+# OpenRouter dynamic models pull pricing automatically from OpenRouter's
+# API — no inline declaration needed. Models without resolvable pricing
+# debit $0 from the user's balance and log a WARNING.
# Router Settings for Auto Mode
# These settings control how the LiteLLM Router distributes requests across models
@@ -245,31 +263,64 @@ global_llm_configs:
# =============================================================================
# When enabled, dynamically fetches ALL available models from the OpenRouter API
# and injects them as global configs. This gives premium users access to any model
-# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota.
+# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium credit quota,
+# while free-tier OpenRouter models show up with a green Free badge and do NOT
+# consume premium quota.
# Models are fetched at startup and refreshed periodically in the background.
# All calls go through LiteLLM with the openrouter/ prefix.
openrouter_integration:
enabled: false
api_key: "sk-or-your-openrouter-api-key"
- # billing_tier: "premium" or "free". Controls whether users need premium tokens.
- billing_tier: "premium"
- # anonymous_enabled: set true to also show OpenRouter models to no-login users
- anonymous_enabled: false
+
+ # Tier is derived PER MODEL from OpenRouter's own API signals:
+ # - id ends with ":free" -> billing_tier=free
+ # - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
+ # - otherwise -> billing_tier=premium
+ # No global billing_tier knob is honored; any legacy value emits a startup warning.
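+ # Example (illustrative ids): "meta-llama/llama-3.3-70b-instruct:free"
+ # -> billing_tier=free; "anthropic/claude-sonnet-4" with non-zero
+ # pricing -> billing_tier=premium.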
+
+ # Anonymous access is split by tier so operators can expose only free
+ # models to no-login users without leaking paid inference.
+ anonymous_enabled_paid: false
+ anonymous_enabled_free: false
+
seo_enabled: false
# quota_reserve_tokens: tokens reserved per call for quota enforcement
quota_reserve_tokens: 4000
- # id_offset: starting negative ID for dynamically generated configs.
- # Must not overlap with your static global_llm_configs IDs above.
+ # id_offset: base negative ID for dynamically generated configs.
+ # Model IDs are derived deterministically via BLAKE2b so they survive
+ # catalogue churn. Must not overlap with your static global_llm_configs IDs.
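+ # Sketch of the assumed shape (the integration service owns the exact
+ # formula): id = id_offset - (blake2b(model_id) % bucket), so a given
+ # OpenRouter model id always lands on the same stable negative id.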
id_offset: -10000
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
refresh_interval_hours: 24
- # rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing.
- # OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled
- # upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits).
- # These values only matter if you set billing_tier to "free" (adding them to Auto mode).
- # For premium-only models they are cosmetic. Set conservatively or match your account tier.
+
+ # Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
+ # for per-deployment accounting when OR premium models participate in the
+ # shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
+ # real account limits live at https://openrouter.ai/settings/limits.
rpm: 200
tpm: 1000000
+
+ # Rate limits for FREE OpenRouter models. Informational only: free OR
+ # models are intentionally kept OUT of the LiteLLM Router pool, because
+ # OpenRouter enforces free-tier limits globally per account (~20 RPM +
+ # 50-1000 daily requests across every ":free" model combined) —
+ # per-deployment router accounting can't represent a shared bucket
+ # correctly. Free OR models stay fully available in the model selector
+ # and for user-facing Auto thread pinning.
+ free_rpm: 20
+ free_tpm: 100000
+
+ # Image generation + vision LLM emissions are OPT-IN. OpenRouter's catalogue
+ # contains hundreds of image- and vision-capable models; turning these on
+ # injects them into the global Image-Generation / Vision-LLM model
+ # selectors alongside any static configs. Tier (free/premium) is derived
+ # per model the same way it is for chat (`:free` suffix or zero pricing).
+ # When a user picks a premium image/vision model the call debits the
+ # shared $5 USD-cost-based premium credit pool — so leaving these off
+ # avoids surprise quota burn on existing deployments. Default: false.
+ image_generation_enabled: false
+ vision_enabled: false
+
litellm_params:
max_tokens: 16384
system_instructions: ""
diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py
index 91d19fb4f..9fc27fb1f 100644
--- a/surfsense_backend/app/db.py
+++ b/surfsense_backend/app/db.py
@@ -638,6 +638,12 @@ class NewChatThread(BaseModel, TimestampMixin):
default=False,
server_default="false",
)
+ # Auto (Fastest) model pin for this thread: concrete resolved global LLM
+ # config id. NULL means no pin; Auto will resolve on the next turn.
+ # Single-writer invariant: only app.services.auto_model_pin_service sets
+ # or clears this column (plus bulk clears when a search space's
+ # agent_llm_id changes). Unindexed: all reads are by primary key.
+ pinned_llm_config_id = Column(Integer, nullable=True)
# Relationships
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
@@ -669,6 +675,23 @@ class NewChatMessage(BaseModel, TimestampMixin):
__tablename__ = "new_chat_messages"
+ # Partial unique index on (thread_id, turn_id, role) where turn_id IS NOT NULL.
+ # Mirrors alembic migration 141. Lets the streaming agent and the
+ # legacy frontend appendMessage call coexist idempotently — the second
+ # writer trips the unique index and recovers without creating a duplicate row.
+ # Partial so legacy NULL turn_id rows and clone/snapshot inserts in
+ # app/services/public_chat_service.py (which omit turn_id) are unaffected.
+ __table_args__ = (
+ Index(
+ "uq_new_chat_messages_thread_turn_role",
+ "thread_id",
+ "turn_id",
+ "role",
+ unique=True,
+ postgresql_where=text("turn_id IS NOT NULL"),
+ ),
+ )
+
role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False)
# Content stored as JSONB to support rich content (text, tool calls, etc.)
content = Column(JSONB, nullable=False)
@@ -722,9 +745,26 @@ class TokenUsage(BaseModel, TimestampMixin):
__tablename__ = "token_usage"
+ # Partial unique index on (message_id) where message_id IS NOT NULL.
+ # Mirrors alembic migration 142. Lets the streaming agent's
+ # ``finalize_assistant_turn`` and the legacy frontend ``append_message``
+ # recovery branch both use ``INSERT ... ON CONFLICT DO NOTHING`` without
+ # racing on a SELECT-then-INSERT window. Partial so non-chat usage rows
+ # (indexing, image generation, podcasts) — which keep ``message_id`` NULL
+ # because there is no per-message anchor — are unaffected.
+ __table_args__ = (
+ Index(
+ "uq_token_usage_message_id",
+ "message_id",
+ unique=True,
+ postgresql_where=text("message_id IS NOT NULL"),
+ ),
+ )
+
prompt_tokens = Column(Integer, nullable=False, default=0)
completion_tokens = Column(Integer, nullable=False, default=0)
total_tokens = Column(Integer, nullable=False, default=0)
+ cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
model_breakdown = Column(JSONB, nullable=True)
call_details = Column(JSONB, nullable=True)
@@ -1787,7 +1827,15 @@ class PagePurchase(Base, TimestampMixin):
class PremiumTokenPurchase(Base, TimestampMixin):
- """Tracks Stripe checkout sessions used to grant additional premium token credits."""
+ """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
+
+ Note: the table name is preserved (``premium_token_purchases``) for
+ operational continuity even though the unit is now USD micro-credits
+ instead of raw tokens. The ``credit_micros_granted`` column replaced
+ the legacy ``tokens_granted`` in migration 140; the stored values
+ were not transformed because the prior $1 = 1M tokens Stripe price
+ makes the unit conversion 1:1 numerically.
+ """
__tablename__ = "premium_token_purchases"
__allow_unmapped__ = True
@@ -1804,7 +1852,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
)
stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
quantity = Column(Integer, nullable=False)
- tokens_granted = Column(BigInteger, nullable=False)
+ credit_micros_granted = Column(BigInteger, nullable=False)
amount_total = Column(Integer, nullable=True)
currency = Column(String(10), nullable=True)
status = Column(
@@ -2103,16 +2151,16 @@ if config.AUTH_TYPE == "GOOGLE":
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
- premium_tokens_limit = Column(
+ premium_credit_micros_limit = Column(
BigInteger,
nullable=False,
- default=config.PREMIUM_TOKEN_LIMIT,
- server_default=str(config.PREMIUM_TOKEN_LIMIT),
+ default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+ server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
)
- premium_tokens_used = Column(
+ premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
- premium_tokens_reserved = Column(
+ premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
@@ -2235,16 +2283,16 @@ else:
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
- premium_tokens_limit = Column(
+ premium_credit_micros_limit = Column(
BigInteger,
nullable=False,
- default=config.PREMIUM_TOKEN_LIMIT,
- server_default=str(config.PREMIUM_TOKEN_LIMIT),
+ default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+ server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
)
- premium_tokens_used = Column(
+ premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
- premium_tokens_reserved = Column(
+ premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
diff --git a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py
index 4bb38b7b0..d45bd780c 100644
--- a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py
+++ b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py
@@ -68,12 +68,25 @@ class EtlPipelineService:
etl_service="VISION_LLM",
content_type="image",
)
- except Exception:
- logging.warning(
- "Vision LLM failed for %s, falling back to document parser",
- request.filename,
- exc_info=True,
- )
+ except Exception as exc:
+ # Special-case quota exhaustion so we log a clearer message
+ # — the vision LLM didn't "fail", the user just ran out of
+ # premium credit. Falling through to the document parser
+ # is a graceful degradation: OCR/Unstructured still
+ # extracts text from the image without burning credit.
+ from app.services.billable_calls import QuotaInsufficientError
+
+ if isinstance(exc, QuotaInsufficientError):
+ logging.info(
+ "Vision LLM quota exhausted for %s; falling back to document parser",
+ request.filename,
+ )
+ else:
+ logging.warning(
+ "Vision LLM failed for %s, falling back to document parser",
+ request.filename,
+ exc_info=True,
+ )
else:
logging.info(
"No vision LLM provided, falling back to document parser for %s",
diff --git a/surfsense_backend/app/routes/agent_flags_route.py b/surfsense_backend/app/routes/agent_flags_route.py
index 5732a8dfb..99388af66 100644
--- a/surfsense_backend/app/routes/agent_flags_route.py
+++ b/surfsense_backend/app/routes/agent_flags_route.py
@@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
+from app.config import config
from app.db import User
from app.users import current_active_user
@@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel):
enable_otel: bool
+ enable_desktop_local_filesystem: bool
+
@classmethod
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
- return cls(**asdict(flags))
+ return cls(
+ **asdict(flags),
+ enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM,
+ )
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
diff --git a/surfsense_backend/app/routes/composio_routes.py b/surfsense_backend/app/routes/composio_routes.py
index 4bf360365..7bc2addf8 100644
--- a/surfsense_backend/app/routes/composio_routes.py
+++ b/surfsense_backend/app/routes/composio_routes.py
@@ -649,13 +649,9 @@ async def list_composio_drive_folders(
"""
List folders AND files in user's Google Drive via Composio.
- Uses the same GoogleDriveClient / list_folder_contents path as the native
- connector, with Composio-sourced credentials. This means auth errors
- propagate identically (Google returns 401 → exception → auth_expired flag).
+ Uses Composio's Google Drive tool execution path so managed OAuth tokens
+ do not need to be exposed through connected account state.
"""
- from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
- from app.utils.google_credentials import build_composio_credentials
-
if not ComposioService.is_enabled():
raise HTTPException(
status_code=503,
@@ -689,10 +685,37 @@ async def list_composio_drive_folders(
detail="Composio connected account not found. Please reconnect the connector.",
)
- credentials = build_composio_credentials(composio_connected_account_id)
- drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
+ service = ComposioService()
+ entity_id = f"surfsense_{user.id}"
+ items = []
+ page_token = None
+ error = None
- items, error = await list_folder_contents(drive_client, parent_id=parent_id)
+ while True:
+ page_items, next_token, page_error = await service.get_drive_files(
+ connected_account_id=composio_connected_account_id,
+ entity_id=entity_id,
+ folder_id=parent_id,
+ page_token=page_token,
+ page_size=100,
+ )
+ if page_error:
+ error = page_error
+ break
+
+ items.extend(page_items)
+ if not next_token:
+ break
+ page_token = next_token
+
+ for item in items:
+ item["isFolder"] = (
+ item.get("mimeType") == "application/vnd.google-apps.folder"
+ )
+
+ items.sort(
+ key=lambda item: (not item["isFolder"], item.get("name", "").lower())
+ )
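+ # Resulting order: folders first, then case-insensitive by name —
+ # e.g. ["Docs/", "Reports/", "a.txt", "B.txt"] (names illustrative).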
if error:
error_lower = error.lower()
diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py
index f558481cf..f1ca3b6bf 100644
--- a/surfsense_backend/app/routes/documents_routes.py
+++ b/surfsense_backend/app/routes/documents_routes.py
@@ -745,6 +745,51 @@ async def search_document_titles(
) from e
+@router.get("/documents/by-virtual-path", response_model=DocumentTitleRead)
+async def get_document_by_virtual_path(
+ search_space_id: int,
+ virtual_path: str,
+ session: AsyncSession = Depends(get_async_session),
+ user: User = Depends(current_active_user),
+):
+ """Resolve a knowledge-base document id by exact virtual path."""
+ try:
+ await check_permission(
+ session,
+ user,
+ search_space_id,
+ Permission.DOCUMENTS_READ.value,
+ "You don't have permission to read documents in this search space",
+ )
+
+ result = await session.execute(
+ select(
+ Document.id,
+ Document.title,
+ Document.document_type,
+ ).filter(
+ Document.search_space_id == search_space_id,
+ Document.document_metadata["virtual_path"].as_string() == virtual_path,
+ )
+ )
+ row = result.first()
+ if row is None:
+ raise HTTPException(status_code=404, detail="Document not found")
+
+ return DocumentTitleRead(
+ id=row.id,
+ title=row.title,
+ document_type=row.document_type,
+ )
+ except HTTPException:
+ raise
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to resolve document by virtual path: {e!s}",
+ ) from e
+
+
@router.get("/documents/status", response_model=DocumentStatusBatchResponse)
async def get_documents_status(
search_space_id: int,
diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py
index 97a3559b9..018234ad5 100644
--- a/surfsense_backend/app/routes/image_generation_routes.py
+++ b/surfsense_backend/app/routes/image_generation_routes.py
@@ -36,11 +36,17 @@ from app.schemas import (
ImageGenerationListRead,
ImageGenerationRead,
)
+from app.services.billable_calls import (
+ DEFAULT_IMAGE_RESERVE_MICROS,
+ QuotaInsufficientError,
+ billable_call,
+)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
is_image_gen_auto_mode,
)
+from app.services.provider_api_base import resolve_api_base
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
@@ -82,14 +88,62 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
return None
+def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
+ """Resolve the LiteLLM provider prefix used in model strings."""
+ if custom_provider:
+ return custom_provider
+ return _PROVIDER_MAP.get(provider.upper(), provider.lower())
+
+
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
"""Build a litellm model string from provider + model_name."""
- if custom_provider:
- return f"{custom_provider}/{model_name}"
- prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
- return f"{prefix}/{model_name}"
+ return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
+
+
+async def _resolve_billing_for_image_gen(
+ session: AsyncSession,
+ config_id: int | None,
+ search_space: SearchSpace,
+) -> tuple[str, str, int]:
+ """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
+
+ The resolution mirrors ``_execute_image_generation``'s lookup tree but
+ only extracts the fields needed for billing — we do this *before*
+ ``billable_call`` so the reservation is correctly sized for the
+ config that will actually run, and so we don't open an
+ ``ImageGeneration`` row for a request that's about to 402.
+
+ User-owned (positive ID) BYOK configs are always free — they cost
+ the user nothing on our side. Auto mode is currently treated as
+ free because the underlying router can dispatch to either premium or
+ free YAML configs and we don't surface the resolved deployment up
+ here yet. Bringing Auto under premium billing would require
+ threading the chosen deployment back from ``ImageGenRouterService``.
+ """
+ resolved_id = config_id
+ if resolved_id is None:
+ resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
+
+ if is_image_gen_auto_mode(resolved_id):
+ return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
+
+ if resolved_id < 0:
+ cfg = _get_global_image_gen_config(resolved_id) or {}
+ billing_tier = str(cfg.get("billing_tier", "free")).lower()
+ base_model = _build_model_string(
+ cfg.get("provider", ""),
+ cfg.get("model_name", ""),
+ cfg.get("custom_provider"),
+ )
+ reserve_micros = int(
+ cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
+ )
+ return (billing_tier, base_model, reserve_micros)
+
+ # Positive ID = user-owned BYOK image-gen config — always free.
+ return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
async def _execute_image_generation(
@@ -138,12 +192,18 @@ async def _execute_image_generation(
if not cfg:
raise ValueError(f"Global image generation config {config_id} not found")
- model_string = _build_model_string(
- cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
+ provider_prefix = _resolve_provider_prefix(
+ cfg.get("provider", ""), cfg.get("custom_provider")
)
+ model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
- if cfg.get("api_base"):
- gen_kwargs["api_base"] = cfg["api_base"]
+ api_base = resolve_api_base(
+ provider=cfg.get("provider"),
+ provider_prefix=provider_prefix,
+ config_api_base=cfg.get("api_base"),
+ )
+ if api_base:
+ gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
@@ -165,12 +225,18 @@ async def _execute_image_generation(
if not db_cfg:
raise ValueError(f"Image generation config {config_id} not found")
- model_string = _build_model_string(
- db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
+ provider_prefix = _resolve_provider_prefix(
+ db_cfg.provider.value, db_cfg.custom_provider
)
+ model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
- if db_cfg.api_base:
- gen_kwargs["api_base"] = db_cfg.api_base
+ api_base = resolve_api_base(
+ provider=db_cfg.provider.value,
+ provider_prefix=provider_prefix,
+ config_api_base=db_cfg.api_base,
+ )
+ if api_base:
+ gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
@@ -225,10 +291,15 @@ async def get_global_image_gen_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
+ # Auto mode currently treated as free until per-deployment
+ # billing-tier surfacing lands (see _resolve_billing_for_image_gen).
+ "billing_tier": "free",
+ "is_premium": False,
}
)
for cfg in global_configs:
+ billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
@@ -241,6 +312,12 @@ async def get_global_image_gen_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": billing_tier,
+ # Mirror chat (``new_llm_config_routes``) so the new-chat
+ # selector's premium badge logic keys off the same
+ # field across chat / image / vision tabs.
+ "is_premium": billing_tier == "premium",
+ "quota_reserve_micros": cfg.get("quota_reserve_micros"),
}
)
@@ -454,7 +531,26 @@ async def create_image_generation(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
- """Create and execute an image generation request."""
+ """Create and execute an image generation request.
+
+ Premium configs are gated by the user's shared premium credit pool.
+ The flow is:
+
+ 1. Permission check + load the search space (cheap, no provider call).
+ 2. Resolve which config will run so we know its billing tier and the
+ worst-case reservation size *before* opening any DB rows.
+ 3. Wrap the entire ImageGeneration row insert + provider call in
+ ``billable_call``. If quota is denied, ``billable_call`` raises
+ ``QuotaInsufficientError`` *before* we flush a row, which we
+ translate to HTTP 402 (no orphaned rows on the user's account,
+ no inserted error rows for "you ran out of credit").
+ 4. On success, the actual ``response_cost`` flows through the
+ LiteLLM callback into the accumulator, and ``billable_call``
+ finalizes the debit at exit. Inner ``try/except`` still catches
+ provider errors and stores them on ``error_message`` (HTTP 200
+ with ``error_message`` set is preserved for failed-but-not-quota
+ scenarios — clients already know how to surface those).
+ """
try:
await check_permission(
session,
@@ -471,33 +567,70 @@ async def create_image_generation(
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
- db_image_gen = ImageGeneration(
- prompt=data.prompt,
- model=data.model,
- n=data.n,
- quality=data.quality,
- size=data.size,
- style=data.style,
- response_format=data.response_format,
- image_generation_config_id=data.image_generation_config_id,
- search_space_id=data.search_space_id,
- created_by_id=user.id,
+ billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
+ session, data.image_generation_config_id, search_space
)
- session.add(db_image_gen)
- await session.flush()
- try:
- await _execute_image_generation(session, db_image_gen, search_space)
- except Exception as e:
- logger.exception("Image generation call failed")
- db_image_gen.error_message = str(e)
+ # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
+ # propagates to the outer ``except QuotaInsufficientError`` handler
+ # below as HTTP 402 — it is intentionally NOT swallowed into
+ # ``error_message`` because that would (1) imply a successful row
+ # exists when none does, and (2) return HTTP 200 to a client
+ # whose request was actively *denied* (issue K).
+ async with billable_call(
+ user_id=search_space.user_id,
+ search_space_id=data.search_space_id,
+ billing_tier=billing_tier,
+ base_model=base_model,
+ quota_reserve_micros_override=reserve_micros,
+ usage_type="image_generation",
+ call_details={"model": base_model, "prompt": data.prompt[:100]},
+ ):
+ db_image_gen = ImageGeneration(
+ prompt=data.prompt,
+ model=data.model,
+ n=data.n,
+ quality=data.quality,
+ size=data.size,
+ style=data.style,
+ response_format=data.response_format,
+ image_generation_config_id=data.image_generation_config_id,
+ search_space_id=data.search_space_id,
+ created_by_id=user.id,
+ )
+ session.add(db_image_gen)
+ await session.flush()
- await session.commit()
- await session.refresh(db_image_gen)
- return db_image_gen
+ try:
+ await _execute_image_generation(session, db_image_gen, search_space)
+ except Exception as e:
+ logger.exception("Image generation call failed")
+ db_image_gen.error_message = str(e)
+
+ await session.commit()
+ await session.refresh(db_image_gen)
+ return db_image_gen
except HTTPException:
raise
+ except QuotaInsufficientError as exc:
+ # The user's premium credit pool is empty. No DB row is created
+ # because ``billable_call`` denies before yielding (issue K).
+ await session.rollback()
+ raise HTTPException(
+ status_code=402,
+ detail={
+ "error_code": "premium_quota_exhausted",
+ "usage_type": exc.usage_type,
+ "used_micros": exc.used_micros,
+ "limit_micros": exc.limit_micros,
+ "remaining_micros": exc.remaining_micros,
+ "message": (
+ "Out of premium credits for image generation. "
+ "Purchase additional credits or switch to a free model."
+ ),
+ },
+ ) from exc
except SQLAlchemyError:
await session.rollback()
raise HTTPException(
diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py
index c95553fce..ad96654f5 100644
--- a/surfsense_backend/app/routes/new_chat_routes.py
+++ b/surfsense_backend/app/routes/new_chat_routes.py
@@ -15,9 +15,10 @@ import json
import logging
from datetime import UTC, datetime
-from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
-from sqlalchemy import func, or_
+from sqlalchemy import func, or_, text as sa_text
+from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
@@ -29,6 +30,12 @@ from app.agents.new_chat.filesystem_selection import (
FilesystemSelection,
LocalFilesystemMount,
)
+from app.agents.new_chat.middleware.busy_mutex import (
+ get_cancel_state,
+ is_cancel_requested,
+ manager,
+ request_cancel,
+)
from app.config import config
from app.db import (
ChatComment,
@@ -38,12 +45,14 @@ from app.db import (
NewChatThread,
Permission,
SearchSpace,
+ TokenUsage,
User,
get_async_session,
shielded_async_session,
)
from app.schemas.new_chat import (
AgentToolInfo,
+ CancelActiveTurnResponse,
LocalFilesystemMountPayload,
NewChatMessageRead,
NewChatRequest,
@@ -60,10 +69,11 @@ from app.schemas.new_chat import (
ThreadListItem,
ThreadListResponse,
TokenUsageSummary,
+ TurnStatusResponse,
)
-from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
+from app.utils.perf import get_perf_logger
from app.utils.rbac import check_permission
from app.utils.user_message_multimodal import (
split_langchain_human_content,
@@ -71,7 +81,11 @@ from app.utils.user_message_multimodal import (
)
_logger = logging.getLogger(__name__)
+_perf_log = get_perf_logger()
_background_tasks: set[asyncio.Task] = set()
+TURN_CANCELLING_INITIAL_DELAY_MS = 200
+TURN_CANCELLING_BACKOFF_FACTOR = 2
+TURN_CANCELLING_MAX_DELAY_MS = 1500
router = APIRouter()
@@ -137,6 +151,72 @@ def _resolve_filesystem_selection(
)
+def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
+ """Bounded exponential delay for TURN_CANCELLING retry hints."""
+ if attempt < 1:
+ attempt = 1
+ delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
+ TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
+ )
+ return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
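+ # Schedule: attempt 1 -> 200ms, 2 -> 400ms, 3 -> 800ms, 4+ -> 1500ms
+ # (1600ms clamped to TURN_CANCELLING_MAX_DELAY_MS).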
+
+
+def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
+ lock = manager.lock_for(str(thread_id))
+ if not lock.locked():
+ return {"status": "idle"}
+
+ if is_cancel_requested(str(thread_id)):
+ cancel_state = get_cancel_state(str(thread_id))
+ attempt = cancel_state[0] if cancel_state else 1
+ retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
+ retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms
+ return {
+ "status": "cancelling",
+ "retry_after_ms": retry_after_ms,
+ "retry_after_at": retry_after_at,
+ }
+
+ return {"status": "busy"}
+
+
+def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
+ response.headers["retry-after-ms"] = str(retry_after_ms)
+ response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000))
+
+
+def _raise_if_thread_busy_for_start(thread_id: int) -> None:
+ status_payload = _build_turn_status_payload(thread_id)
+ status = status_payload["status"]
+ if status == "idle":
+ return
+ if status == "cancelling":
+ retry_after_ms = int(status_payload.get("retry_after_ms") or 0)
+ detail = {
+ "errorCode": "TURN_CANCELLING",
+ "message": "A previous response is still stopping. Please try again in a moment.",
+ "retry_after_ms": retry_after_ms if retry_after_ms > 0 else None,
+ "retry_after_at": status_payload.get("retry_after_at"),
+ }
+ headers = (
+ {
+ "retry-after-ms": str(retry_after_ms),
+ "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)),
+ }
+ if retry_after_ms > 0
+ else None
+ )
+ raise HTTPException(status_code=409, detail=detail, headers=headers)
+
+ raise HTTPException(
+ status_code=409,
+ detail={
+ "errorCode": "THREAD_BUSY",
+ "message": "Another response is still finishing for this thread. Please try again in a moment.",
+ },
+ )
+
+
def _find_pre_turn_checkpoint_id(
checkpoint_tuples: list,
*,
@@ -1210,6 +1290,24 @@ async def append_message(
user: User = Depends(current_active_user),
):
"""
+ .. deprecated:: 2026-05
+ Replaced by the **SSE-based message ID handshake**. The streaming
+ generator (`stream_new_chat` / `stream_resume_chat`) now persists
+ both the user and assistant rows server-side via
+ ``persist_user_turn`` / ``persist_assistant_shell`` and emits
+ ``data-user-message-id`` / ``data-assistant-message-id`` SSE events
+ so the frontend can rename its optimistic IDs in real time. The
+ new FE bundle no longer calls this route.
+
+ This handler is retained as a **silent no-op for legacy / cached
+ FE bundles**: the underlying ``INSERT ... ON CONFLICT DO NOTHING``
+ pattern means a stale bundle hitting this route after the SSE
+ handshake already wrote the row simply returns the existing row
+ (200 OK) without raising or duplicating data. After a 2-week soak
+ (target: ``[persist_user_turn] outcome=race_recovered`` rate ~0)
+ this entire route — and the FE ``appendMessage`` function — is
+ earmarked for removal.
+
Append a message to a thread.
This is used by ThreadHistoryAdapter.append() to persist messages.
@@ -1220,6 +1318,22 @@ async def append_message(
Requires CHATS_UPDATE permission.
"""
try:
+ # Capture ``user.id`` as a primitive UUID up front. The
+ # ``current_active_user`` dependency hands us a ``User`` ORM
+ # row bound to ``session``; if the outer ``except
+ # IntegrityError`` block below ever fires (an unexpected
+ # constraint like a foreign key violation — the common
+ # ``(thread_id, turn_id, role)`` race is now handled silently
+ # by ``ON CONFLICT DO NOTHING`` so it never raises) it calls
+ # ``session.rollback()``, which expires every attached ORM
+ # row including this user. Any later ``user.id`` access would
+ # then trigger a lazy PK reload — which on async SQLAlchemy
+ # fails with ``MissingGreenlet`` because the reload happens
+ # outside the awaitable greenlet boundary. Reading ``id``
+ # once here pins the value as a plain UUID so all downstream
+ # uses (TokenUsage insert, response build) are immune.
+ user_uuid = user.id
+
# Parse raw body - extract only role and content, ignoring extra fields
raw_body = await request.json()
role = raw_body.get("role")
@@ -1274,37 +1388,166 @@ async def append_message(
else None
)
- db_message = NewChatMessage(
- thread_id=thread_id,
- role=message_role,
- content=content,
- author_id=user.id,
- turn_id=turn_id_value,
- )
- session.add(db_message)
-
- # Update thread's updated_at timestamp
+ # Update thread's updated_at timestamp (always — both insert
+ # and recovery paths represent thread activity).
thread.updated_at = datetime.now(UTC)
- # flush assigns the PK/defaults without a round-trip SELECT
- await session.flush()
+ # Insert the new message via ``INSERT ... ON CONFLICT DO NOTHING``
+ # keyed on the ``(thread_id, turn_id, role)`` partial unique
+ # index from migration 141 (``WHERE turn_id IS NOT NULL``).
+ #
+ # Why ON CONFLICT instead of ``session.add() + flush() + except
+ # IntegrityError``:
+ # 1. The conflict between this legacy FE ``appendMessage``
+ # round-trip and the server-side
+ # ``finalize_assistant_turn`` writer is a NORMAL,
+ # *expected* race — every assistant turn fires it. Using
+ # catch-and-recover means asyncpg raises
+ # ``UniqueViolationError`` -> SQLAlchemy wraps it as
+ # ``IntegrityError`` -> our handler catches and recovers.
+ # Functionally fine, but every ``raise`` event lights up
+ # VS Code's debugger (debugpy's ``justMyCode=false`` mode
+ # loses track of the catch frame across SQLAlchemy's
+ # async greenlet boundary, so even ``Raised Exceptions``
+ # being unchecked doesn't reliably suppress the pause).
+ # ON CONFLICT pushes the conflict resolution into Postgres
+ # where no Python exception is constructed at all.
+ # 2. No ``session.rollback()`` -> no expiring of attached
+ # ORM rows -> no risk of ``MissingGreenlet`` from
+ # lazy-loading expired user/thread state later in the
+ # handler.
+ # 3. Cleaner production logs (no SQLAlchemy ``IntegrityError``
+ # tracebacks emitted by uvicorn's logger between the
+ # ``raise`` and our ``except``).
+ #
+ # When ``turn_id_value`` is ``None`` the partial index doesn't
+ # apply and the INSERT proceeds normally. Other constraint
+ # violations (FK, NOT NULL, etc.) still raise ``IntegrityError``
+ # and are caught by the outer ``except IntegrityError`` block
+ # to preserve the legacy 400 behavior.
+ #
+ # Note on ``content``: when we recover the existing row, we
+ # intentionally discard the FE's ``content`` payload from
+ # ``raw_body`` and return the row's existing ``content``. The
+ # streaming task is now the *authoritative writer* for
+ # assistant ``ContentPart[]`` shape (mid-stream
+ # ``AssistantContentBuilder`` -> ``finalize_assistant_turn``)
+ # so the FE's later ``appendMessage`` is just a stale snapshot
+ # of the same data — keeping the server-built rich content
+ # (with full tool-call args / argsText / langchainToolCallId)
+ # is correct, not lossy.
+ insert_stmt = (
+ pg_insert(NewChatMessage)
+ .values(
+ thread_id=thread_id,
+ role=message_role,
+ content=content,
+ author_id=user_uuid,
+ turn_id=turn_id_value,
+ )
+ .on_conflict_do_nothing(
+ index_elements=["thread_id", "turn_id", "role"],
+ index_where=sa_text("turn_id IS NOT NULL"),
+ )
+ .returning(NewChatMessage.id)
+ )
+ inserted_id = (await session.execute(insert_stmt)).scalar()
- # Persist token usage if provided (for assistant messages)
+ if inserted_id is None:
+ # Conflict on partial unique index — server-side stream
+ # already wrote this row. Look it up and reuse it.
+ if turn_id_value is None:
+ # Defensive: ON CONFLICT only fires for ``turn_id IS
+ # NOT NULL`` rows, so this branch should be
+ # unreachable. Preserve the legacy 400 just in case
+ # Postgres ever surprises us.
+ raise HTTPException(
+ status_code=400,
+ detail="Database constraint violation. Please check your input data.",
+ ) from None
+ lookup = await session.execute(
+ select(NewChatMessage).filter(
+ NewChatMessage.thread_id == thread_id,
+ NewChatMessage.turn_id == turn_id_value,
+ NewChatMessage.role == message_role,
+ )
+ )
+ existing_message = lookup.scalars().first()
+ if existing_message is None:
+ # Conflict reported but the row vanished between
+ # INSERT and SELECT — extremely unlikely (would
+ # require a concurrent DELETE within the same
+ # transaction visibility), but preserve safe
+ # behavior.
+ raise HTTPException(
+ status_code=400,
+ detail="Database constraint violation. Please check your input data.",
+ ) from None
+ db_message = existing_message
+ # Perf signal: counts how often the legacy FE round-trip
+ # races the server-side ``finalize_assistant_turn``. A
+ # rising rate after the rework is OK (it's exactly the
+ # ghost-thread fix's recovery path firing); a sudden drop
+ # to zero would mean the FE isn't posting appendMessage
+ # at all (different bug).
+ _perf_log.info(
+ "[append_message] outcome=recovered_via_unique_index "
+ "thread_id=%s turn_id=%s role=%s message_id=%s",
+ thread_id,
+ turn_id_value,
+ message_role.value,
+ db_message.id,
+ )
+ else:
+ # INSERT succeeded — load the full ORM row so the
+ # response can include server-side-defaulted columns
+ # (``created_at``, etc.) and the relationship surface
+ # stays consistent with the recovery path.
+ inserted_row = await session.get(NewChatMessage, inserted_id)
+ if inserted_row is None:
+ # Should be impossible: we just inserted it in this
+ # same transaction. Fail loud if it happens.
+ raise HTTPException(
+ status_code=500,
+ detail="Inserted message could not be loaded.",
+ ) from None
+ db_message = inserted_row
+
+ # Persist token usage if provided (for assistant messages).
+ # ``cost_micros`` is the provider USD cost reported by LiteLLM,
+ # forwarded by the FE through the appendMessage round-trip so
+ # the historical TokenUsage row matches the credit debit applied
+ # at finalize time.
+ #
+ # De-dup: ``finalize_assistant_turn`` may also race to write a
+ # token_usage row for this same ``message_id`` (cross-session,
+ # cross-shielded). Use ``INSERT ... ON CONFLICT DO NOTHING`` keyed
+ # on the ``uq_token_usage_message_id`` partial unique index
+ # (migration 142). The loser silently drops its insert; exactly
+ # one row results regardless of which writer commits first.
token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
- await record_token_usage(
- session,
- usage_type="chat",
- search_space_id=thread.search_space_id,
- user_id=user.id,
- prompt_tokens=token_usage_data.get("prompt_tokens", 0),
- completion_tokens=token_usage_data.get("completion_tokens", 0),
- total_tokens=token_usage_data.get("total_tokens", 0),
- model_breakdown=token_usage_data.get("usage"),
- call_details=token_usage_data.get("call_details"),
- thread_id=thread_id,
- message_id=db_message.id,
+ insert_stmt = (
+ pg_insert(TokenUsage)
+ .values(
+ usage_type="chat",
+ prompt_tokens=token_usage_data.get("prompt_tokens", 0),
+ completion_tokens=token_usage_data.get("completion_tokens", 0),
+ total_tokens=token_usage_data.get("total_tokens", 0),
+ cost_micros=token_usage_data.get("cost_micros", 0),
+ model_breakdown=token_usage_data.get("usage"),
+ call_details=token_usage_data.get("call_details"),
+ thread_id=thread_id,
+ message_id=db_message.id,
+ search_space_id=thread.search_space_id,
+ user_id=user_uuid,
+ )
+ .on_conflict_do_nothing(
+ index_elements=["message_id"],
+ index_where=sa_text("message_id IS NOT NULL"),
+ )
)
+ await session.execute(insert_stmt)
await session.commit()
@@ -1324,6 +1567,9 @@ async def append_message(
except HTTPException:
raise
except IntegrityError:
+ # Any IntegrityError that escaped the inline handler above
+ # comes from a *different* constraint (foreign key, etc.) —
+ # preserve the legacy 400 path.
await session.rollback()
raise HTTPException(
status_code=400,
@@ -1476,6 +1722,7 @@ async def handle_new_chat(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
+ _raise_if_thread_busy_for_start(request.chat_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
@@ -1516,6 +1763,12 @@ async def handle_new_chat(
else None
)
+ mentioned_documents_payload = (
+ [doc.model_dump() for doc in request.mentioned_documents]
+ if request.mentioned_documents
+ else None
+ )
+
return StreamingResponse(
stream_new_chat(
user_query=request.user_query,
@@ -1525,6 +1778,7 @@ async def handle_new_chat(
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
+ mentioned_documents=mentioned_documents_payload,
needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,
current_user_display_name=user.display_name or "A team member",
@@ -1550,6 +1804,93 @@ async def handle_new_chat(
) from None
+@router.post(
+ "/threads/{thread_id}/cancel-active-turn",
+ response_model=CancelActiveTurnResponse,
+)
+async def cancel_active_turn(
+ thread_id: int,
+ response: Response,
+ session: AsyncSession = Depends(get_async_session),
+ user: User = Depends(current_active_user),
+):
+ """Signal cancellation for the currently running turn on ``thread_id``."""
+ result = await session.execute(
+ select(NewChatThread).filter(NewChatThread.id == thread_id)
+ )
+ thread = result.scalars().first()
+ if not thread:
+ raise HTTPException(status_code=404, detail="Thread not found")
+
+ await check_permission(
+ session,
+ user,
+ thread.search_space_id,
+ Permission.CHATS_UPDATE.value,
+ "You don't have permission to update chats in this search space",
+ )
+ await check_thread_access(session, thread, user)
+
+ status_payload = _build_turn_status_payload(thread_id)
+ if status_payload["status"] == "idle":
+ return CancelActiveTurnResponse(
+ status="idle",
+ error_code="NO_ACTIVE_TURN",
+ )
+
+ request_cancel(str(thread_id))
+ response.status_code = 202
+ updated_payload = _build_turn_status_payload(thread_id)
+ retry_after_ms = int(updated_payload.get("retry_after_ms") or 0)
+ retry_after_at = (
+ int(updated_payload["retry_after_at"])
+ if "retry_after_at" in updated_payload
+ else None
+ )
+ if retry_after_ms > 0:
+ _set_retry_after_headers(response, retry_after_ms)
+ return CancelActiveTurnResponse(
+ status="cancelling",
+ error_code="TURN_CANCELLING",
+ retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,
+ retry_after_at=retry_after_at,
+ )
+
+
+@router.get(
+ "/threads/{thread_id}/turn-status",
+ response_model=TurnStatusResponse,
+)
+async def get_turn_status(
+ thread_id: int,
+ session: AsyncSession = Depends(get_async_session),
+ user: User = Depends(current_active_user),
+):
+ result = await session.execute(
+ select(NewChatThread).filter(NewChatThread.id == thread_id)
+ )
+ thread = result.scalars().first()
+ if not thread:
+ raise HTTPException(status_code=404, detail="Thread not found")
+
+ await check_permission(
+ session,
+ user,
+ thread.search_space_id,
+ Permission.CHATS_READ.value,
+ "You don't have permission to view chats in this search space",
+ )
+ await check_thread_access(session, thread, user)
+
+ status_payload = _build_turn_status_payload(thread_id)
+ return TurnStatusResponse(
+ status=status_payload["status"], # type: ignore[arg-type]
+ active_turn_id=None,
+ retry_after_ms=status_payload.get("retry_after_ms"), # type: ignore[arg-type]
+ retry_after_at=status_payload.get("retry_after_at"), # type: ignore[arg-type]
+ )
+
+
# =============================================================================
# Chat Regeneration Endpoint (Edit/Reload)
# =============================================================================
@@ -1605,6 +1946,7 @@ async def regenerate_response(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
+ _raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
@@ -1907,6 +2249,11 @@ async def regenerate_response(
"data": revert_results,
}
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
+ mentioned_documents_payload = (
+ [doc.model_dump() for doc in request.mentioned_documents]
+ if request.mentioned_documents
+ else None
+ )
try:
async for chunk in stream_new_chat(
user_query=str(user_query_to_use),
@@ -1916,6 +2263,7 @@ async def regenerate_response(
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
+ mentioned_documents=mentioned_documents_payload,
checkpoint_id=target_checkpoint_id,
needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,
@@ -1924,6 +2272,7 @@ async def regenerate_response(
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=regenerate_image_urls or None,
+ flow="regenerate",
):
yield chunk
streaming_completed = True
@@ -2011,6 +2360,7 @@ async def resume_chat(
)
await check_thread_access(session, thread, user)
+ _raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py
index 20779a309..e090a1a7c 100644
--- a/surfsense_backend/app/routes/new_llm_config_routes.py
+++ b/surfsense_backend/app/routes/new_llm_config_routes.py
@@ -29,6 +29,7 @@ from app.schemas import (
NewLLMConfigUpdate,
)
from app.services.llm_service import validate_llm_config
+from app.services.provider_capabilities import derive_supports_image_input
from app.users import current_active_user
from app.utils.rbac import check_permission
@@ -36,6 +37,39 @@ router = APIRouter()
logger = logging.getLogger(__name__)
+def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
+ """Augment a BYOK chat config row with the derived ``supports_image_input``.
+
+ There is no DB column for ``supports_image_input`` — the value is
+ resolved at the API boundary from LiteLLM's authoritative model map
+ (default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps
+ the response shape consistent across list / detail / create / update
+ endpoints without having to remember to set the field at every call
+ site.
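+
+ Rough shape of the underlying capability probe (a sketch; the real
+ resolution, including the default-allow stance, lives in
+ ``provider_capabilities``)::
+
+ import litellm
+ litellm.supports_vision(model="gpt-4o") # vision-capable in LiteLLM's map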
+ """
+ provider_value = (
+ config.provider.value
+ if hasattr(config.provider, "value")
+ else str(config.provider)
+ )
+ litellm_params = config.litellm_params or {}
+ base_model = (
+ litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
+ )
+ supports_image_input = derive_supports_image_input(
+ provider=provider_value,
+ model_name=config.model_name,
+ base_model=base_model,
+ custom_provider=config.custom_provider,
+ )
+ # ``model_validate`` runs the Pydantic conversion using the ORM
+ # attribute access path enabled by ``ConfigDict(from_attributes=True)``,
+ # then we layer the derived field on. ``model_copy(update=...)`` keeps
+ # the surface immutable from the caller's perspective.
+ base_read = NewLLMConfigRead.model_validate(config)
+ return base_read.model_copy(update={"supports_image_input": supports_image_input})
+
+
# =============================================================================
# Global Configs Routes
# =============================================================================
@@ -84,11 +118,41 @@ async def get_global_new_llm_configs(
"seo_title": None,
"seo_description": None,
"quota_reserve_tokens": None,
+ # Auto routes across the configured pool, which usually
+ # includes at least one vision-capable deployment, so
+ # treat Auto as image-capable. The router itself will
+ # still pick a vision-capable deployment for messages
+ # carrying image_url blocks (LiteLLM Router falls back
+ # on ``404`` per its ``allowed_fails`` policy).
+ "supports_image_input": True,
}
)
# Add individual global configs
for cfg in global_configs:
+ # Capability resolution: an explicit value wins (either a YAML
+ # override or the `_supports_image_input(model)` payload baked in
+ # by the OpenRouter integration service). Fall back to the
+ # LiteLLM-driven helper, which default-allows on unknown so
+ # we don't hide vision-capable models that happen to lack a
+ # YAML annotation. The streaming-task safety net is the
+ # only place a False ever blocks.
+ if "supports_image_input" in cfg:
+ supports_image_input = bool(cfg.get("supports_image_input"))
+ else:
+ cfg_litellm_params = cfg.get("litellm_params") or {}
+ cfg_base_model = (
+ cfg_litellm_params.get("base_model")
+ if isinstance(cfg_litellm_params, dict)
+ else None
+ )
+ supports_image_input = derive_supports_image_input(
+ provider=cfg.get("provider"),
+ model_name=cfg.get("model_name"),
+ base_model=cfg_base_model,
+ custom_provider=cfg.get("custom_provider"),
+ )
+
safe_config = {
"id": cfg.get("id"),
"name": cfg.get("name"),
@@ -113,6 +177,7 @@ async def get_global_new_llm_configs(
"seo_title": cfg.get("seo_title"),
"seo_description": cfg.get("seo_description"),
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
+ "supports_image_input": supports_image_input,
}
safe_configs.append(safe_config)
@@ -171,7 +236,7 @@ async def create_new_llm_config(
await session.commit()
await session.refresh(db_config)
- return db_config
+ return _serialize_byok_config(db_config)
except HTTPException:
raise
@@ -213,7 +278,7 @@ async def list_new_llm_configs(
.limit(limit)
)
- return result.scalars().all()
+ return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
except HTTPException:
raise
@@ -268,7 +333,7 @@ async def get_new_llm_config(
"You don't have permission to view LLM configurations in this search space",
)
- return config
+ return _serialize_byok_config(config)
except HTTPException:
raise
@@ -360,7 +425,7 @@ async def update_new_llm_config(
await session.commit()
await session.refresh(config)
- return config
+ return _serialize_byok_config(config)
except HTTPException:
raise
diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py
index 828137518..5ecfb1814 100644
--- a/surfsense_backend/app/routes/search_spaces_routes.py
+++ b/surfsense_backend/app/routes/search_spaces_routes.py
@@ -3,7 +3,7 @@ import logging
from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import HumanMessage
from pydantic import BaseModel as PydanticBaseModel
-from sqlalchemy import func
+from sqlalchemy import func, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
@@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem
from app.config import config
from app.db import (
ImageGenerationConfig,
+ NewChatThread,
NewLLMConfig,
Permission,
SearchSpace,
@@ -593,6 +594,7 @@ async def _get_image_gen_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
+ "billing_tier": "free",
}
if config_id < 0:
@@ -609,6 +611,7 @@ async def _get_image_gen_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": cfg.get("billing_tier", "free"),
}
return None
@@ -651,6 +654,7 @@ async def _get_vision_llm_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
+ "billing_tier": "free",
}
if config_id < 0:
@@ -667,6 +671,7 @@ async def _get_vision_llm_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": cfg.get("billing_tier", "free"),
}
return None
@@ -790,9 +795,27 @@ async def update_llm_preferences(
# Update preferences
update_data = preferences.model_dump(exclude_unset=True)
+ previous_agent_llm_id = search_space.agent_llm_id
for key, value in update_data.items():
setattr(search_space, key, value)
+ agent_llm_changed = (
+ "agent_llm_id" in update_data
+ and update_data["agent_llm_id"] != previous_agent_llm_id
+ )
+ if agent_llm_changed:
+ await session.execute(
+ update(NewChatThread)
+ .where(NewChatThread.search_space_id == search_space_id)
+ .values(pinned_llm_config_id=None)
+ )
+ logger.info(
+ "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
+ search_space_id,
+ previous_agent_llm_id,
+ update_data["agent_llm_id"],
+ )
+
await session.commit()
await session.refresh(search_space)
diff --git a/surfsense_backend/app/routes/stripe_routes.py b/surfsense_backend/app/routes/stripe_routes.py
index cfdd4b52a..aed74ec8d 100644
--- a/surfsense_backend/app/routes/stripe_routes.py
+++ b/surfsense_backend/app/routes/stripe_routes.py
@@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
metadata = _get_metadata(checkout_session)
user_id = metadata.get("user_id")
quantity = int(metadata.get("quantity", "0"))
- tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
+ # Read the new metadata key first, then fall back to the legacy one
+ # so in-flight checkout sessions created before the cost-credits
+ # release still fulfill correctly (the unit is numerically the
+ # same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
+ credit_micros_per_unit = int(
+ metadata.get("credit_micros_per_unit")
+ or metadata.get("tokens_per_unit", "0")
+ )
- if not user_id or quantity <= 0 or tokens_per_unit <= 0:
+ if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
logger.error(
"Skipping token fulfillment for session %s: incomplete metadata %s",
checkout_session_id,
@@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
getattr(checkout_session, "payment_intent", None)
),
quantity=quantity,
- tokens_granted=quantity * tokens_per_unit,
+ credit_micros_granted=quantity * credit_micros_per_unit,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
purchase.stripe_payment_intent_id = _normalize_optional_string(
getattr(checkout_session, "payment_intent", None)
)
- user.premium_tokens_limit = (
- max(user.premium_tokens_used, user.premium_tokens_limit)
- + purchase.tokens_granted
+ # Top up the user's credit balance by the granted micro-USD amount.
+ # ``max(used, limit)`` clamps the case where the legacy code wrote a
+ # used value above the limit (e.g. rounding artifacts in the legacy
+ # accounting), so adding ``credit_micros_granted`` always lifts the
+ # limit by the full pack size rather than disappearing into past
+ # overuse.
+ user.premium_credit_micros_limit = (
+ max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
+ + purchase.credit_micros_granted
)
await db_session.commit()
@@ -532,12 +544,18 @@ async def create_token_checkout_session(
user: User = Depends(current_active_user),
db_session: AsyncSession = Depends(get_async_session),
):
- """Create a Stripe Checkout Session for buying premium token packs."""
+ """Create a Stripe Checkout Session for buying premium credit packs.
+
+ Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
+ credit (default 1_000_000 = $1.00). The user's balance is debited
+ at the actual provider cost reported by LiteLLM at finalize time,
+ so $1 of credit always buys $1 worth of provider usage at cost.
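+
+ Illustrative arithmetic (the quantity is made up)::
+
+ credit_micros_granted = 5 * 1_000_000 # 5_000_000 micro-USD == $5.00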
+ """
_ensure_token_buying_enabled()
stripe_client = get_stripe_client()
price_id = _get_required_token_price_id()
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
- tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
+ credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
try:
checkout_session = stripe_client.v1.checkout.sessions.create(
@@ -556,8 +574,8 @@ async def create_token_checkout_session(
"metadata": {
"user_id": str(user.id),
"quantity": str(body.quantity),
- "tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
- "purchase_type": "premium_tokens",
+ "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
+ "purchase_type": "premium_credit",
},
}
)
@@ -583,7 +601,7 @@ async def create_token_checkout_session(
getattr(checkout_session, "payment_intent", None)
),
quantity=body.quantity,
- tokens_granted=tokens_granted,
+ credit_micros_granted=credit_micros_granted,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@@ -598,14 +616,19 @@ async def create_token_checkout_session(
async def get_token_status(
user: User = Depends(current_active_user),
):
- """Return token-buying availability and current premium quota for frontend."""
- used = user.premium_tokens_used
- limit = user.premium_tokens_limit
+ """Return token-buying availability and current premium credit quota for frontend.
+
+ Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
+ when displaying. The route name is preserved for back-compat with
+ pinned client deployments.
+ """
+ used = user.premium_credit_micros_used
+ limit = user.premium_credit_micros_limit
return TokenStripeStatusResponse(
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
- premium_tokens_used=used,
- premium_tokens_limit=limit,
- premium_tokens_remaining=max(0, limit - used),
+ premium_credit_micros_used=used,
+ premium_credit_micros_limit=limit,
+ premium_credit_micros_remaining=max(0, limit - used),
)
diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py
index 315c7c9fe..e4f08f604 100644
--- a/surfsense_backend/app/routes/vision_llm_routes.py
+++ b/surfsense_backend/app/routes/vision_llm_routes.py
@@ -82,10 +82,15 @@ async def get_global_vision_llm_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
+ # Auto mode is treated as free until per-deployment billing-tier
+ # surfacing lands; see ``get_vision_llm`` for parity.
+ "billing_tier": "free",
+ "is_premium": False,
}
)
for cfg in global_configs:
+ billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
@@ -98,6 +103,14 @@ async def get_global_vision_llm_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": billing_tier,
+ # Mirror chat (``new_llm_config_routes``) so the new-chat
+ # selector's premium badge logic keys off the same
+ # field across chat / image / vision tabs.
+ "is_premium": billing_tier == "premium",
+ "quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
+ "input_cost_per_token": cfg.get("input_cost_per_token"),
+ "output_cost_per_token": cfg.get("output_cost_per_token"),
}
)
diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py
index 69f534e20..4262b2b3f 100644
--- a/surfsense_backend/app/schemas/image_generation.py
+++ b/surfsense_backend/app/schemas/image_generation.py
@@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel):
Schema for reading global image generation configs from YAML.
Global configs have negative IDs. API key is hidden.
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
+
+ The ``billing_tier`` field allows the frontend to show a Premium/Free
+ badge and (more importantly) tells the backend whether to debit the
+ user's premium credit pool when this config is used. ``"free"`` is
+ the default for backward compatibility — admins must explicitly opt
+ a global config into ``"premium"``.
"""
id: int = Field(
@@ -231,3 +237,24 @@ class GlobalImageGenConfigRead(BaseModel):
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
+ billing_tier: str = Field(
+ default="free",
+ description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
+ )
+ is_premium: bool = Field(
+ default=False,
+ description=(
+ "Convenience boolean derived server-side from "
+ "``billing_tier == 'premium'``. The new-chat model selector "
+ "keys its Free/Premium badge off this field for parity with "
+ "chat (`GlobalLLMConfigRead.is_premium`)."
+ ),
+ )
+ quota_reserve_micros: int | None = Field(
+ default=None,
+ description=(
+ "Optional override for the reservation amount (in micro-USD) used when "
+ "this image generation is premium. Falls back to "
+ "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
+ ),
+ )
diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py
index cfb4b8b37..95d183433 100644
--- a/surfsense_backend/app/schemas/new_chat.py
+++ b/surfsense_backend/app/schemas/new_chat.py
@@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
+ cost_micros: int = 0
model_breakdown: dict | None = None
model_config = ConfigDict(from_attributes=True)
@@ -199,6 +200,21 @@ class NewChatUserImagePart(BaseModel):
return to_data_url(self.media_type, self.data)
+class MentionedDocumentInfo(BaseModel):
+ """Display metadata for a single ``@``-mentioned document.
+
+ The full triple ``{id, title, document_type}`` is forwarded by the
+ frontend mention chip so the server can embed it in the persisted
+ user message ``ContentPart[]`` (single ``mentioned-documents`` part).
+ The history loader then renders the chips on reload without an extra
+ fetch — mirrors the pre-refactor frontend ``persistUserTurn`` shape.
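+
+ Example part payload (shape only; the values are illustrative)::
+
+ {"id": 12, "title": "Q3 roadmap", "document_type": "NOTION"}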
+ """
+
+ id: int
+ title: str = Field(..., min_length=1, max_length=500)
+ document_type: str = Field(..., min_length=1, max_length=100)
+
+
class NewChatRequest(BaseModel):
"""Request schema for the deep agent chat endpoint."""
@@ -212,6 +228,17 @@ class NewChatRequest(BaseModel):
mentioned_surfsense_doc_ids: list[int] | None = (
None # Optional SurfSense documentation IDs mentioned with @ in the chat
)
+ mentioned_documents: list[MentionedDocumentInfo] | None = Field(
+ default=None,
+ description=(
+ "Display metadata (id, title, document_type) for every "
+ "@-mentioned document. Persisted as a ``mentioned-documents`` "
+ "ContentPart on the user message so reload renders chips "
+ "without an extra fetch. Optional and additive — when None "
+ "the user message is persisted without a mentioned-documents "
+ "part."
+ ),
+ )
disabled_tools: list[str] | None = (
None # Optional list of tool names the user has disabled from the UI
)
@@ -263,6 +290,16 @@ class RegenerateRequest(BaseModel):
)
mentioned_document_ids: list[int] | None = None
mentioned_surfsense_doc_ids: list[int] | None = None
+ mentioned_documents: list[MentionedDocumentInfo] | None = Field(
+ default=None,
+ description=(
+ "Display metadata (id, title, document_type) for every "
+ "@-mentioned document on the edited user turn. Only used "
+ "when ``user_query`` is non-None (edit). Persisted as a "
+ "``mentioned-documents`` ContentPart on the new user "
+ "message. None means no chip metadata."
+ ),
+ )
disabled_tools: list[str] | None = None
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
@@ -336,6 +373,34 @@ class ResumeRequest(BaseModel):
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
+ mentioned_documents: list[MentionedDocumentInfo] | None = Field(
+ default=None,
+ description=(
+ "Display metadata forwarded for symmetry with /new_chat and "
+ "/regenerate. Resume reuses the original interrupted user "
+ "turn so the server does not write a new user message. "
+ "Currently unused but accepted to keep request bodies "
+ "uniform across the three streaming entrypoints."
+ ),
+ )
+
+
+class CancelActiveTurnResponse(BaseModel):
+ """Response for canceling an active turn on a chat thread."""
+
+ status: Literal["cancelling", "idle"]
+ error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"]
+ retry_after_ms: int | None = None
+ retry_after_at: int | None = None
+
+
+class TurnStatusResponse(BaseModel):
+ """Current turn execution status for a thread."""
+
+ status: Literal["idle", "busy", "cancelling"]
+ active_turn_id: str | None = None
+ retry_after_ms: int | None = None
+ retry_after_at: int | None = None
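+
+
+ # Illustrative client flow for the two schemas above (a sketch; the cancel
+ # route's path is on its decorator earlier in this diff):
+ #
+ # 1. POST the cancel endpoint -> 202 CancelActiveTurnResponse(
+ # status="cancelling", error_code="TURN_CANCELLING", ...)
+ # 2. GET /threads/{thread_id}/turn-status until TurnStatusResponse.status
+ # returns to "idle", honoring any Retry-After headers on the way.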
# =============================================================================
diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py
index 9cc1fce58..e64478d38 100644
--- a/surfsense_backend/app/schemas/new_llm_config.py
+++ b/surfsense_backend/app/schemas/new_llm_config.py
@@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase):
created_at: datetime
search_space_id: int
user_id: uuid.UUID
+ # Capability flag derived at the API boundary (no DB column). Default
+ # True matches the conservative-allow stance — a BYOK row that the
+ # route forgot to augment is not pre-judged. The streaming-task
+ # safety net is the only place a False actually blocks a request.
+ supports_image_input: bool = Field(
+ default=True,
+ description=(
+ "Whether the BYOK chat config can accept image inputs. Derived "
+ "at the route boundary from LiteLLM's authoritative model map "
+ "(``litellm.supports_vision``) — there is no DB column. "
+ "Default True is the conservative-allow stance for unknown / "
+ "unmapped models."
+ ),
+ )
model_config = ConfigDict(from_attributes=True)
@@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel):
created_at: datetime
search_space_id: int
user_id: uuid.UUID
+ # Capability flag derived at the API boundary (see NewLLMConfigRead).
+ supports_image_input: bool = Field(
+ default=True,
+ description=(
+ "Whether the BYOK chat config can accept image inputs. Derived "
+ "at the route boundary from LiteLLM's authoritative model map. "
+ "Default True is the conservative-allow stance."
+ ),
+ )
model_config = ConfigDict(from_attributes=True)
@@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel):
seo_title: str | None = None
seo_description: str | None = None
quota_reserve_tokens: int | None = None
+ supports_image_input: bool = Field(
+ default=True,
+ description=(
+ "Whether the model accepts image inputs (multimodal vision). "
+ "Derived server-side: OpenRouter dynamic configs use "
+ "``architecture.input_modalities``; YAML / BYOK use LiteLLM's "
+ "authoritative model map (``litellm.supports_vision``). The "
+ "new-chat selector hints with a 'No image' badge when this is "
+ "False and there are pending image attachments. The streaming "
+ "task fails fast only when LiteLLM *explicitly* marks a model "
+ "as text-only — unknown / unmapped models default-allow."
+ ),
+ )
# =============================================================================
diff --git a/surfsense_backend/app/schemas/stripe.py b/surfsense_backend/app/schemas/stripe.py
index 3edd3e9e4..57265ec8e 100644
--- a/surfsense_backend/app/schemas/stripe.py
+++ b/surfsense_backend/app/schemas/stripe.py
@@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel):
class TokenPurchaseRead(BaseModel):
- """Serialized premium token purchase record."""
+ """Serialized premium credit purchase record.
+
+ ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The
+ schema name keeps ``Token`` for API back-compat with pinned clients.
+ """
id: uuid.UUID
stripe_checkout_session_id: str
stripe_payment_intent_id: str | None = None
quantity: int
- tokens_granted: int
+ credit_micros_granted: int
amount_total: int | None = None
currency: str | None = None
status: str
@@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel):
class TokenPurchaseHistoryResponse(BaseModel):
- """Response containing the user's premium token purchases."""
+ """Response containing the user's premium credit purchases."""
purchases: list[TokenPurchaseRead]
class TokenStripeStatusResponse(BaseModel):
- """Response describing token-buying availability and current quota."""
+ """Response describing premium-credit-buying availability and balance.
+
+ All ``premium_credit_micros_*`` fields are in micro-USD; the FE
+ divides by 1_000_000 to display USD.
+ """
token_buying_enabled: bool
- premium_tokens_used: int = 0
- premium_tokens_limit: int = 0
- premium_tokens_remaining: int = 0
+ premium_credit_micros_used: int = 0
+ premium_credit_micros_limit: int = 0
+ premium_credit_micros_remaining: int = 0
diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py
index ab2e609dc..d0eeaf5c6 100644
--- a/surfsense_backend/app/schemas/vision_llm.py
+++ b/surfsense_backend/app/schemas/vision_llm.py
@@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel):
class GlobalVisionLLMConfigRead(BaseModel):
+ """Schema for reading global vision LLM configs from YAML.
+
+ The ``billing_tier`` field allows the frontend to show a Premium/Free
+ badge and (more importantly) tells the backend whether to debit the
+ user's premium credit pool when this config is used. ``"free"`` is
+ the default for backward compatibility — admins must explicitly opt
+ a global config into ``"premium"``.
+ """
+
id: int = Field(...)
name: str
description: str | None = None
@@ -73,3 +82,35 @@ class GlobalVisionLLMConfigRead(BaseModel):
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
+ billing_tier: str = Field(
+ default="free",
+ description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
+ )
+ is_premium: bool = Field(
+ default=False,
+ description=(
+ "Convenience boolean derived server-side from "
+ "``billing_tier == 'premium'``. The new-chat model selector "
+ "keys its Free/Premium badge off this field for parity with "
+ "chat (`GlobalLLMConfigRead.is_premium`)."
+ ),
+ )
+ quota_reserve_tokens: int | None = Field(
+ default=None,
+ description=(
+ "Optional override for the per-call reservation in *tokens* — "
+ "converted to micro-USD via the model's input/output prices at "
+ "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
+ ),
+ )
+ input_cost_per_token: float | None = Field(
+ default=None,
+ description=(
+ "Optional input price in USD/token. Used by pricing_registration to "
+ "register custom Azure / OpenRouter aliases with LiteLLM at startup."
+ ),
+ )
+ output_cost_per_token: float | None = Field(
+ default=None,
+ description="Optional output price in USD/token. Pair with input_cost_per_token.",
+ )
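+
+
+ # Illustrative YAML-derived dict this schema accepts (every value below is
+ # made up):
+ # {"id": -3, "name": "Vision GPT-4o", "billing_tier": "premium",
+ # "is_premium": True, "quota_reserve_tokens": 4096,
+ # "input_cost_per_token": 2.5e-06, "output_cost_per_token": 1e-05}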
diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py
new file mode 100644
index 000000000..9bbca8669
--- /dev/null
+++ b/surfsense_backend/app/services/auto_model_pin_service.py
@@ -0,0 +1,479 @@
+"""Resolve and persist Auto (Fastest) model pins per chat thread.
+
+Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
+resolve that virtual mode to one concrete global LLM config exactly once and
+persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so
+subsequent turns are stable.
+
+Single-writer invariant: this module is the only writer of
+``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in
+``search_spaces_routes`` when a search space's ``agent_llm_id`` changes).
+Therefore a non-NULL value unambiguously means "this thread has an
+Auto-resolved pin"; no separate source/policy column is needed.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import logging
+import threading
+import time
+from dataclasses import dataclass
+from uuid import UUID
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.config import config
+from app.db import NewChatThread
+from app.services.quality_score import _QUALITY_TOP_K
+from app.services.token_quota_service import TokenQuotaService
+
+logger = logging.getLogger(__name__)
+
+AUTO_FASTEST_ID = 0
+AUTO_FASTEST_MODE = "auto_fastest"
+_RUNTIME_COOLDOWN_SECONDS = 600
+_HEALTHY_TTL_SECONDS = 45
+
+# In-memory runtime cooldown map for configs that recently hard-failed at
+# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
+# the same unhealthy config from being reselected immediately during repair.
+_runtime_cooldown_until: dict[int, float] = {}
+_runtime_cooldown_lock = threading.Lock()
+
+# Short-TTL "recently healthy" cache for configs that just passed a runtime
+# preflight ping. Lets back-to-back turns on the same model skip the probe
+# without eroding correctness — entries auto-expire and are wiped any time
+# the same config is cooled down or the OR catalogue is refreshed.
+_healthy_until: dict[int, float] = {}
+_healthy_lock = threading.Lock()
+
+
+@dataclass
+class AutoPinResolution:
+ resolved_llm_config_id: int
+ resolved_tier: str
+ from_existing_pin: bool
+
+
+def _is_usable_global_config(cfg: dict) -> bool:
+ return bool(
+ cfg.get("id") is not None
+ and cfg.get("model_name")
+ and cfg.get("provider")
+ and cfg.get("api_key")
+ )
+
+
+def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
+ now = time.time() if now_ts is None else now_ts
+ stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
+ for cid in stale:
+ _runtime_cooldown_until.pop(cid, None)
+
+
+def _is_runtime_cooled_down(config_id: int) -> bool:
+ with _runtime_cooldown_lock:
+ _prune_runtime_cooldowns()
+ return config_id in _runtime_cooldown_until
+
+
+def mark_runtime_cooldown(
+ config_id: int,
+ *,
+ reason: str = "rate_limited",
+ cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS,
+) -> None:
+ """Temporarily suppress a config from Auto selection.
+
+ Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned
+ config that is currently unhealthy does not get immediately reused on the
+ same thread during repair.
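+
+ Sketch of the intended call site (the handler shape and exception
+ type are assumptions)::
+
+ except litellm.RateLimitError: # e.g. OpenRouter 429
+ mark_runtime_cooldown(cfg_id, reason="rate_limited")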
+ """
+ if cooldown_seconds <= 0:
+ cooldown_seconds = _RUNTIME_COOLDOWN_SECONDS
+ until = time.time() + int(cooldown_seconds)
+ with _runtime_cooldown_lock:
+ _runtime_cooldown_until[int(config_id)] = until
+ _prune_runtime_cooldowns()
+ # A cooled cfg can never be "recently healthy"; drop any stale credit so
+ # the next turn that resolves to it (after cooldown) re-runs preflight.
+ clear_healthy(int(config_id))
+ logger.info(
+ "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s",
+ config_id,
+ reason,
+ cooldown_seconds,
+ )
+
+
+def clear_runtime_cooldown(config_id: int | None = None) -> None:
+ """Test/ops helper to clear runtime cooldown entries."""
+ with _runtime_cooldown_lock:
+ if config_id is None:
+ _runtime_cooldown_until.clear()
+ return
+ _runtime_cooldown_until.pop(int(config_id), None)
+
+
+def _prune_healthy(now_ts: float | None = None) -> None:
+ now = time.time() if now_ts is None else now_ts
+ stale = [cid for cid, until in _healthy_until.items() if until <= now]
+ for cid in stale:
+ _healthy_until.pop(cid, None)
+
+
+def is_recently_healthy(config_id: int) -> bool:
+ """Return True if ``config_id`` passed preflight within the TTL window."""
+ with _healthy_lock:
+ _prune_healthy()
+ return int(config_id) in _healthy_until
+
+
+def mark_healthy(
+ config_id: int,
+ *,
+ ttl_seconds: int = _HEALTHY_TTL_SECONDS,
+) -> None:
+ """Record that ``config_id`` just passed a preflight probe.
+
+ Subsequent calls within ``ttl_seconds`` can skip the preflight ping. The
+ healthy state is intentionally process-local — it's a latency hint, not a
+ correctness primitive — so multi-worker drift is acceptable.
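+
+ Typical preflight pattern (a sketch; ``preflight_ping`` is a
+ hypothetical helper, not part of this module)::
+
+ if not is_recently_healthy(cfg_id):
+ await preflight_ping(cfg) # hypothetical probe; raises on failure
+ mark_healthy(cfg_id)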
+ """
+ if ttl_seconds <= 0:
+ ttl_seconds = _HEALTHY_TTL_SECONDS
+ until = time.time() + int(ttl_seconds)
+ with _healthy_lock:
+ _healthy_until[int(config_id)] = until
+ _prune_healthy()
+
+
+def clear_healthy(config_id: int | None = None) -> None:
+ """Drop one (or all) healthy-cache entries.
+
+ Called from runtime cooldown and OR catalogue refresh so a freshly cooled
+ or replaced config never carries stale "healthy" credit.
+ """
+ with _healthy_lock:
+ if config_id is None:
+ _healthy_until.clear()
+ return
+ _healthy_until.pop(int(config_id), None)
+
+
+def _cfg_supports_image_input(cfg: dict) -> bool:
+ """True if the global cfg can accept image inputs.
+
+ Prefers the explicit ``supports_image_input`` flag (set by the YAML
+ loader / OpenRouter integration). Falls back to a LiteLLM lookup so
+ a YAML entry whose flag was somehow stripped doesn't get wrongly
+ excluded. Default-allows on unknown — the streaming-task safety net
+ is the actual block, not this filter.
+ """
+ if "supports_image_input" in cfg:
+ return bool(cfg.get("supports_image_input"))
+ # Lazy import: provider_capabilities -> llm_config -> services chain;
+ # importing at module load would create an init-order cycle through
+ # ``app.config``.
+ from app.services.provider_capabilities import derive_supports_image_input
+
+ cfg_litellm_params = cfg.get("litellm_params") or {}
+ base_model = (
+ cfg_litellm_params.get("base_model")
+ if isinstance(cfg_litellm_params, dict)
+ else None
+ )
+ return derive_supports_image_input(
+ provider=cfg.get("provider"),
+ model_name=cfg.get("model_name"),
+ base_model=base_model,
+ custom_provider=cfg.get("custom_provider"),
+ )
+
+
+def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
+ """Return Auto-eligible global cfgs.
+
+ Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
+ below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
+ can't be picked as the thread's pin. Also excludes configs currently
+ in runtime cooldown (e.g. temporary 429 bursts).
+
+ When ``requires_image_input`` is True (image turn), additionally
+ filters out configs whose ``supports_image_input`` resolves to False
+ so a text-only deployment can't be pinned for an image request.
+ """
+ candidates = [
+ cfg
+ for cfg in config.GLOBAL_LLM_CONFIGS
+ if _is_usable_global_config(cfg)
+ and not cfg.get("health_gated")
+ and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
+ and (not requires_image_input or _cfg_supports_image_input(cfg))
+ ]
+ return sorted(candidates, key=lambda c: int(c.get("id", 0)))
+
+
+def _tier_of(cfg: dict) -> str:
+ return str(cfg.get("billing_tier", "free")).lower()
+
+
+def _is_preferred_premium_auto_config(cfg: dict) -> bool:
+ """Return True for the operator-preferred premium Auto model."""
+ return (
+ _tier_of(cfg) == "premium"
+ and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI"
+ and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
+ )
+
+
+def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
+ """Pick a config with quality-first ranking + deterministic spread.
+
+ Tier policy is lock-first: prefer Tier A (operator-curated YAML)
+ cfgs and only fall through to Tier B/C (dynamic OpenRouter) if no
+ Tier A cfg is eligible after upstream filters. Within the locked
+ pool, sort by ``quality_score`` and pick from the top-K via a
+ SHA256 digest of ``f"{AUTO_FASTEST_MODE}:{thread_id}"`` so
+ different new threads spread across the best models without ever
+ picking a low-ranked one.
+
+ Returns ``(chosen_cfg, top_k_size)``. ``top_k_size`` is exposed for
+ structured logging in the caller.
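+
+ Worked sketch of the spread (the top-K size of 3 is illustrative;
+ the real value is ``_QUALITY_TOP_K``)::
+
+ digest = hashlib.sha256(b"auto_fastest:42").digest()
+ idx = int.from_bytes(digest[:8], "big") % 3 # stable index into top-3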
+ """
+ tier_a = [c for c in eligible if c.get("auto_pin_tier") in (None, "A")]
+ pool = tier_a if tier_a else eligible
+ pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0))
+ top_k = pool[:_QUALITY_TOP_K]
+ digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
+ idx = int.from_bytes(digest[:8], "big") % len(top_k)
+ return top_k[idx], len(top_k)
+
+
+def _to_uuid(user_id: str | UUID | None) -> UUID | None:
+ if user_id is None:
+ return None
+ if isinstance(user_id, UUID):
+ return user_id
+ try:
+ return UUID(str(user_id))
+ except Exception:
+ return None
+
+
+async def _is_premium_eligible(
+ session: AsyncSession, user_id: str | UUID | None
+) -> bool:
+ parsed = _to_uuid(user_id)
+ if parsed is None:
+ return False
+ usage = await TokenQuotaService.premium_get_usage(session, parsed)
+ return bool(usage.allowed)
+
+
+async def resolve_or_get_pinned_llm_config_id(
+ session: AsyncSession,
+ *,
+ thread_id: int,
+ search_space_id: int,
+ user_id: str | UUID | None,
+ selected_llm_config_id: int,
+ force_repin_free: bool = False,
+ exclude_config_ids: set[int] | None = None,
+ requires_image_input: bool = False,
+) -> AutoPinResolution:
+ """Resolve Auto (Fastest) to one concrete config id and persist the pin.
+
+ For non-auto selections, this function clears any existing pin and returns
+ the selected id as-is.
+
+ When ``requires_image_input`` is True (the current turn carries an
+ ``image_url`` block), the candidate pool is filtered to vision-capable
+ cfgs and any existing pin that can't accept image input is treated as
+ invalid (force re-pin). If no vision-capable cfg is available the
+ function raises ``ValueError`` so the streaming task surfaces the same
+ friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of
+ silently routing the image to a text-only deployment.
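+
+ Usage sketch (caller-side; the attribute spelling of the inputs is
+ illustrative)::
+
+ resolution = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=thread.id,
+ search_space_id=space.id,
+ user_id=str(user.id),
+ selected_llm_config_id=AUTO_FASTEST_ID,
+ )
+ llm_config_id = resolution.resolved_llm_config_id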
+ """
+ thread = (
+ (
+ await session.execute(
+ select(NewChatThread)
+ .where(NewChatThread.id == thread_id)
+ .with_for_update(of=NewChatThread)
+ )
+ )
+ .unique()
+ .scalar_one_or_none()
+ )
+ if thread is None:
+ raise ValueError(f"Thread {thread_id} not found")
+ if thread.search_space_id != search_space_id:
+ raise ValueError(
+ f"Thread {thread_id} does not belong to search space {search_space_id}"
+ )
+
+ # Explicit model selected: clear any stale pin.
+ if selected_llm_config_id != AUTO_FASTEST_ID:
+ if thread.pinned_llm_config_id is not None:
+ thread.pinned_llm_config_id = None
+ await session.commit()
+ return AutoPinResolution(
+ resolved_llm_config_id=selected_llm_config_id,
+ resolved_tier="explicit",
+ from_existing_pin=False,
+ )
+
+ excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
+ candidates = [
+ c
+ for c in _global_candidates(requires_image_input=requires_image_input)
+ if int(c.get("id", 0)) not in excluded_ids
+ ]
+ if not candidates:
+ if requires_image_input:
+ # Distinguish the "no vision-capable cfg" case from generic
+ # "no usable cfg" so the streaming task can map this to the
+ # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
+ raise ValueError(
+ "No vision-capable global LLM configs are available for Auto mode"
+ )
+ raise ValueError("No usable global LLM configs are available for Auto mode")
+ candidate_by_id = {int(c["id"]): c for c in candidates}
+
+ # Reuse an existing valid pin without re-checking current quota (no silent
+ # tier switch), unless the caller explicitly requests a forced repin to free
+ # *or* the turn requires image input but the pin can't handle it.
+ pinned_id = thread.pinned_llm_config_id
+ if (
+ not force_repin_free
+ and pinned_id is not None
+ and int(pinned_id) in candidate_by_id
+ ):
+ pinned_cfg = candidate_by_id[int(pinned_id)]
+ logger.info(
+ "auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s",
+ thread_id,
+ search_space_id,
+ pinned_id,
+ _tier_of(pinned_cfg),
+ )
+ logger.info(
+ "auto_pin_resolved thread_id=%s config_id=%s tier=%s "
+ "auto_pin_tier=%s score=%s top_k_size=0 from_existing_pin=True",
+ thread_id,
+ pinned_id,
+ _tier_of(pinned_cfg),
+ pinned_cfg.get("auto_pin_tier", "?"),
+ int(pinned_cfg.get("quality_score") or 0),
+ )
+ return AutoPinResolution(
+ resolved_llm_config_id=int(pinned_id),
+ resolved_tier=_tier_of(pinned_cfg),
+ from_existing_pin=True,
+ )
+ if pinned_id is not None:
+ # If the pin is *only* invalid because it can't handle the image
+ # turn (it's still a healthy, usable config in the broader pool),
+ # log that explicitly so operators can correlate the re-pin with
+ # the user's image attachment instead of suspecting a cooldown.
+ if requires_image_input:
+ try:
+ pinned_global = next(
+ c
+ for c in config.GLOBAL_LLM_CONFIGS
+ if int(c.get("id", 0)) == int(pinned_id)
+ )
+ except StopIteration:
+ pinned_global = None
+ if pinned_global is not None and not _cfg_supports_image_input(
+ pinned_global
+ ):
+ logger.info(
+ "auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
+ "previous_config_id=%s",
+ thread_id,
+ search_space_id,
+ pinned_id,
+ )
+ logger.info(
+ "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
+ thread_id,
+ search_space_id,
+ pinned_id,
+ )
+
+ premium_eligible = (
+ False if force_repin_free else await _is_premium_eligible(session, user_id)
+ )
+ if premium_eligible:
+ premium_candidates = [c for c in candidates if _tier_of(c) == "premium"]
+ preferred_premium = [
+ c for c in premium_candidates if _is_preferred_premium_auto_config(c)
+ ]
+ eligible = preferred_premium or premium_candidates
+ else:
+ eligible = [c for c in candidates if _tier_of(c) != "premium"]
+
+ if not eligible:
+ if requires_image_input:
+ raise ValueError(
+ "Auto mode could not find a vision-capable LLM config for this user and quota state"
+ )
+ raise ValueError(
+ "Auto mode could not find an eligible LLM config for this user and quota state"
+ )
+
+ selected_cfg, top_k_size = _select_pin(eligible, thread_id)
+ selected_id = int(selected_cfg["id"])
+ selected_tier = _tier_of(selected_cfg)
+
+ thread.pinned_llm_config_id = selected_id
+ await session.commit()
+
+ if force_repin_free:
+ logger.info(
+ "auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s",
+ thread_id,
+ search_space_id,
+ pinned_id,
+ selected_id,
+ )
+
+ if pinned_id is None:
+ logger.info(
+ "auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
+ thread_id,
+ search_space_id,
+ selected_id,
+ selected_tier,
+ premium_eligible,
+ )
+ else:
+ logger.info(
+ "auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
+ thread_id,
+ search_space_id,
+ pinned_id,
+ selected_id,
+ selected_tier,
+ premium_eligible,
+ )
+
+ logger.info(
+ "auto_pin_resolved thread_id=%s config_id=%s tier=%s "
+ "auto_pin_tier=%s score=%s top_k_size=%d from_existing_pin=False",
+ thread_id,
+ selected_id,
+ selected_tier,
+ selected_cfg.get("auto_pin_tier", "?"),
+ int(selected_cfg.get("quality_score") or 0),
+ top_k_size,
+ )
+
+ return AutoPinResolution(
+ resolved_llm_config_id=selected_id,
+ resolved_tier=selected_tier,
+ from_existing_pin=False,
+ )
diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py
new file mode 100644
index 000000000..92ccd6a78
--- /dev/null
+++ b/surfsense_backend/app/services/billable_calls.py
@@ -0,0 +1,566 @@
+"""
+Per-call billable wrapper for image generation, vision LLM extraction, and
+any other short-lived premium operation that must charge against the user's
+shared premium credit pool.
+
+The ``billable_call`` async context manager encapsulates the standard
+"reserve → execute → finalize / release → record audit row" lifecycle in a
+single primitive so callers (the image-generation REST route and the
+vision-LLM wrapper used during indexing) don't have to re-implement it.
+
+KEY DESIGN POINTS (issue A, B):
+
+1. **Session isolation.** ``billable_call`` takes no caller transaction.
+ All ``TokenQuotaService.premium_*`` calls and the audit-row insert run
+ inside their own session context. Route callers use
+ ``shielded_async_session()`` by default; Celery callers can provide a
+ worker-loop-safe session factory. This guarantees that quota
+ commit/rollback can never accidentally flush or roll back rows the caller
+ has staged in its main session (e.g. a freshly-created
+ ``ImageGeneration`` row).
+
+2. **ContextVar safety.** The accumulator is scoped via
+ :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
+ nested ``billable_call`` inside an outer chat turn cannot corrupt the
+ chat turn's accumulator.
+
+3. **Free configs are still audited.** Free calls bypass the reserve /
+ finalize dance entirely but still record a ``TokenUsage`` audit row with
+ the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
+ pipeline complete for analytics even when nothing is debited.
+
+4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
+ responsible for translating that into HTTP 402. We *do not* catch the
+ denial inside ``billable_call`` — letting it propagate also prevents
+ the image-generation route from creating an ``ImageGeneration`` row
+ for a request that never actually ran.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from collections.abc import AsyncIterator, Callable
+from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
+from typing import Any
+from uuid import UUID, uuid4
+
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.config import config
+from app.db import shielded_async_session
+from app.services.token_quota_service import (
+ TokenQuotaService,
+ estimate_call_reserve_micros,
+)
+from app.services.token_tracking_service import (
+ TurnTokenAccumulator,
+ record_token_usage,
+ scoped_turn,
+)
+
+logger = logging.getLogger(__name__)
+
+AUDIT_TIMEOUT_SECONDS = 10.0
+BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset(
+ {"video_presentation_generation", "podcast_generation"}
+)
+BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]]
+
+
+class QuotaInsufficientError(Exception):
+ """Raised when ``TokenQuotaService.premium_reserve`` denies a billable
+ call because the user has exhausted their premium credit pool.
+
+ The route handler should catch this and return HTTP 402 Payment
+ Required (or the equivalent for the surface area). Outside of the HTTP
+ layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing)
+ callers may catch this and degrade gracefully — e.g. fall back to OCR
+ when vision is unavailable.
+ """
+
+ def __init__(
+ self,
+ *,
+ usage_type: str,
+ used_micros: int,
+ limit_micros: int,
+ remaining_micros: int,
+ ) -> None:
+ self.usage_type = usage_type
+ self.used_micros = used_micros
+ self.limit_micros = limit_micros
+ self.remaining_micros = remaining_micros
+ super().__init__(
+ f"Premium credit exhausted for {usage_type}: "
+ f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)"
+ )
+
+
+class BillingSettlementError(Exception):
+ """Raised when a premium call completed but credit settlement failed."""
+
+ def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None:
+ self.usage_type = usage_type
+ self.user_id = user_id
+ super().__init__(
+ f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}"
+ )
+
+
+async def _rollback_safely(session: AsyncSession) -> None:
+ rollback = getattr(session, "rollback", None)
+ if rollback is not None:
+ with suppress(Exception):
+ await rollback()
+
+
+async def _record_audit_best_effort(
+ *,
+ session_factory: BillableSessionFactory,
+ usage_type: str,
+ search_space_id: int,
+ user_id: UUID,
+ prompt_tokens: int,
+ completion_tokens: int,
+ total_tokens: int,
+ cost_micros: int,
+ model_breakdown: dict[str, Any],
+ call_details: dict[str, Any] | None,
+ thread_id: int | None,
+ message_id: int | None,
+ audit_label: str,
+ timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
+) -> None:
+ """Persist a TokenUsage row without letting audit failure block callers.
+
+ Premium settlement is mandatory, but TokenUsage is an audit trail. If the
+ audit insert or commit hangs, user-facing artifacts such as videos and
+ podcasts must still be able to transition to READY after settlement.
+ """
+ audit_thread_id = (
+ None if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES else thread_id
+ )
+
+ async def _persist() -> None:
+ logger.info(
+ "[billable_call] audit start label=%s usage_type=%s user=%s thread=%s "
+ "total_tokens=%d cost_micros=%d",
+ audit_label,
+ usage_type,
+ user_id,
+ audit_thread_id,
+ total_tokens,
+ cost_micros,
+ )
+ async with session_factory() as audit_session:
+ try:
+ await record_token_usage(
+ audit_session,
+ usage_type=usage_type,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ cost_micros=cost_micros,
+ model_breakdown=model_breakdown,
+ call_details=call_details,
+ thread_id=audit_thread_id,
+ message_id=message_id,
+ )
+ logger.info(
+ "[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s",
+ audit_label,
+ usage_type,
+ user_id,
+ audit_thread_id,
+ )
+ await audit_session.commit()
+ logger.info(
+ "[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s",
+ audit_label,
+ usage_type,
+ user_id,
+ audit_thread_id,
+ )
+ except BaseException:
+ await _rollback_safely(audit_session)
+ raise
+
+ try:
+ await asyncio.wait_for(_persist(), timeout=timeout_seconds)
+ except TimeoutError:
+ logger.warning(
+ "[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s "
+ "timeout=%.1fs total_tokens=%d cost_micros=%d",
+ audit_label,
+ usage_type,
+ user_id,
+ audit_thread_id,
+ timeout_seconds,
+ total_tokens,
+ cost_micros,
+ )
+ except Exception:
+ logger.exception(
+ "[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s "
+ "total_tokens=%d cost_micros=%d",
+ audit_label,
+ usage_type,
+ user_id,
+ audit_thread_id,
+ total_tokens,
+ cost_micros,
+ )
+
+
+@asynccontextmanager
+async def billable_call(
+ *,
+ user_id: UUID,
+ search_space_id: int,
+ billing_tier: str,
+ base_model: str,
+ quota_reserve_tokens: int | None = None,
+ quota_reserve_micros_override: int | None = None,
+ usage_type: str,
+ thread_id: int | None = None,
+ message_id: int | None = None,
+ call_details: dict[str, Any] | None = None,
+ billable_session_factory: BillableSessionFactory | None = None,
+ audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
+) -> AsyncIterator[TurnTokenAccumulator]:
+ """Wrap a single billable LLM/image call.
+
+ Args:
+ user_id: Owner of the credit pool to debit. For vision-LLM during
+ indexing this is the *search-space owner* (issue M), not the
+ triggering user.
+ search_space_id: Required — recorded on the ``TokenUsage`` audit row.
+ billing_tier: ``"premium"`` debits; anything else (``"free"``) skips
+ the reserve/finalize dance but still records an audit row with
+ the captured cost.
+ base_model: Used by :func:`estimate_call_reserve_micros` to compute
+ a worst-case reservation from LiteLLM's pricing table.
+ quota_reserve_tokens: Optional per-config override for the chat-style
+ reserve estimator (vision LLM uses this).
+ quota_reserve_micros_override: Optional flat micro-USD reservation
+ (image generation uses this — its cost shape is per-image, not
+ per-token).
+ usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
+ Recorded on the ``TokenUsage`` row.
+ thread_id, message_id: Optional FK columns on ``TokenUsage``.
+ call_details: Optional per-call metadata (model name, parameters)
+ forwarded to ``record_token_usage``.
+ billable_session_factory: Optional async context factory used for
+ reserve/finalize/release/audit sessions. Defaults to
+ ``shielded_async_session`` for route callers; Celery callers pass
+ a worker-loop-safe session factory.
+ audit_timeout_seconds: Upper bound for TokenUsage audit persistence.
+ Audit failure is best-effort and does not undo successful
+ settlement.
+
+ Yields:
+ The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
+ the underlying LLM/image API while inside the ``async with``; the
+ ``TokenTrackingCallback`` populates the accumulator automatically.
+
+ Raises:
+ QuotaInsufficientError: when premium and ``premium_reserve`` denies.
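+
+ Usage sketch (``generate_image`` and the literal values are
+ illustrative, not lifted from a real call site)::
+
+ async with billable_call(
+ user_id=owner_id,
+ search_space_id=space_id,
+ billing_tier="premium",
+ base_model="gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ) as acc:
+ await generate_image(prompt) # cost lands on ``acc`` via callback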
+ """
+ is_premium = billing_tier == "premium"
+ session_factory = billable_session_factory or shielded_async_session
+
+ async with scoped_turn() as acc:
+ # ---------- Free path: just audit -------------------------------
+ if not is_premium:
+ try:
+ yield acc
+ finally:
+ # Always audit, even on exception, so we capture cost when the
+ # provider returns successfully but the caller raises later.
+ await _record_audit_best_effort(
+ session_factory=session_factory,
+ usage_type=usage_type,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ prompt_tokens=acc.total_prompt_tokens,
+ completion_tokens=acc.total_completion_tokens,
+ total_tokens=acc.grand_total,
+ cost_micros=acc.total_cost_micros,
+ model_breakdown=acc.per_message_summary(),
+ call_details=call_details,
+ thread_id=thread_id,
+ message_id=message_id,
+ audit_label="free",
+ timeout_seconds=audit_timeout_seconds,
+ )
+ return
+
+ # ---------- Premium path: reserve → execute → finalize ----------
+ if quota_reserve_micros_override is not None:
+ reserve_micros = max(1, int(quota_reserve_micros_override))
+ else:
+ reserve_micros = estimate_call_reserve_micros(
+ base_model=base_model or "",
+ quota_reserve_tokens=quota_reserve_tokens,
+ )
+
+ request_id = str(uuid4())
+
+ async with session_factory() as quota_session:
+ reserve_result = await TokenQuotaService.premium_reserve(
+ db_session=quota_session,
+ user_id=user_id,
+ request_id=request_id,
+ reserve_micros=reserve_micros,
+ )
+
+ if not reserve_result.allowed:
+ logger.info(
+ "[billable_call] reserve DENIED user=%s usage_type=%s "
+ "reserve=%d used=%d limit=%d remaining=%d",
+ user_id,
+ usage_type,
+ reserve_micros,
+ reserve_result.used,
+ reserve_result.limit,
+ reserve_result.remaining,
+ )
+ raise QuotaInsufficientError(
+ usage_type=usage_type,
+ used_micros=reserve_result.used,
+ limit_micros=reserve_result.limit,
+ remaining_micros=reserve_result.remaining,
+ )
+
+ logger.info(
+ "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
+ "(remaining=%d)",
+ user_id,
+ usage_type,
+ reserve_micros,
+ reserve_result.remaining,
+ )
+
+ try:
+ yield acc
+ except BaseException:
+ # Release on any failure (including QuotaInsufficientError raised
+ # from a downstream call, asyncio cancellation, etc.). We use
+ # BaseException so cancellation also releases.
+ try:
+ async with session_factory() as quota_session:
+ await TokenQuotaService.premium_release(
+ db_session=quota_session,
+ user_id=user_id,
+ reserved_micros=reserve_micros,
+ )
+ except Exception:
+ logger.exception(
+ "[billable_call] premium_release failed for user=%s "
+ "reserve_micros=%d (reservation will be GC'd by quota "
+ "reconciliation if/when implemented)",
+ user_id,
+ reserve_micros,
+ )
+ raise
+
+ # ---------- Success: finalize + audit ----------------------------
+ actual_micros = acc.total_cost_micros
+ try:
+ logger.info(
+ "[billable_call] finalize start user=%s usage_type=%s actual=%d "
+ "reserved=%d thread=%s",
+ user_id,
+ usage_type,
+ actual_micros,
+ reserve_micros,
+ thread_id,
+ )
+ async with session_factory() as quota_session:
+ final_result = await TokenQuotaService.premium_finalize(
+ db_session=quota_session,
+ user_id=user_id,
+ request_id=request_id,
+ actual_micros=actual_micros,
+ reserved_micros=reserve_micros,
+ )
+ logger.info(
+ "[billable_call] finalize user=%s usage_type=%s actual=%d "
+ "reserved=%d → used=%d/%d (remaining=%d)",
+ user_id,
+ usage_type,
+ actual_micros,
+ reserve_micros,
+ final_result.used,
+ final_result.limit,
+ final_result.remaining,
+ )
+ except Exception as finalize_exc:
+ # Last-ditch: if finalize itself fails, we must at least release
+ # so the reservation doesn't leak.
+ logger.exception(
+ "[billable_call] premium_finalize failed for user=%s; "
+ "attempting release",
+ user_id,
+ )
+ try:
+ async with session_factory() as quota_session:
+ await TokenQuotaService.premium_release(
+ db_session=quota_session,
+ user_id=user_id,
+ reserved_micros=reserve_micros,
+ )
+ except Exception:
+ logger.exception(
+ "[billable_call] release after finalize failure ALSO failed "
+ "for user=%s",
+ user_id,
+ )
+ raise BillingSettlementError(
+ usage_type=usage_type,
+ user_id=user_id,
+ cause=finalize_exc,
+ ) from finalize_exc
+
+ await _record_audit_best_effort(
+ session_factory=session_factory,
+ usage_type=usage_type,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ prompt_tokens=acc.total_prompt_tokens,
+ completion_tokens=acc.total_completion_tokens,
+ total_tokens=acc.grand_total,
+ cost_micros=actual_micros,
+ model_breakdown=acc.per_message_summary(),
+ call_details=call_details,
+ thread_id=thread_id,
+ message_id=message_id,
+ audit_label="premium",
+ timeout_seconds=audit_timeout_seconds,
+ )
+
+
+async def _resolve_agent_billing_for_search_space(
+ session: AsyncSession,
+ search_space_id: int,
+ *,
+ thread_id: int | None = None,
+) -> tuple[UUID, str, str]:
+ """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
+ agent LLM.
+
+ Used by Celery tasks (podcast generation, video presentation) to bill the
+ search-space owner's premium credit pool when the agent LLM is premium.
+
+ Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
+
+ - Search space not found / no ``agent_llm_id``: raise ``ValueError``.
+ - **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
+ * ``thread_id`` is set: delegate to
+ ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
+ recurse into the resolved id. Reuses chat's existing pin if present
+ so the same model bills for chat + downstream podcast/video. If the
+ user is not premium-eligible, the pin service auto-restricts to free
+ deployments — denial only happens later in
+ ``billable_call.premium_reserve`` if the pin really is premium and
+ credit ran out mid-flow.
+ * ``thread_id`` is None: fall back to ``("free", "auto")``. Forward-compat
+ for any future direct-API path; today both Celery tasks always pass
+ ``thread_id``.
+ - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
+ (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
+ ``base_model = litellm_params.get("base_model") or model_name`` —
+ NOT provider-prefixed, matching chat's cost-map lookup convention.
+ - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
+ ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
+ ``base_model`` from ``litellm_params`` or ``model_name``.
+
+ Note on imports: ``llm_service`` and ``auto_model_pin_service`` are
+ imported lazily inside the function body to avoid hoisting litellm
+ side-effects (``litellm.callbacks = [token_tracker]``,
+ ``litellm.drop_params``, etc.) into
+ ``billable_calls.py``'s module load path.
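+
+ Hedged usage sketch (IDs illustrative; ``session`` is any open
+ ``AsyncSession``)::
+
+ owner_id, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, 42, thread_id=7
+ )
+ if tier == "premium":
+ ... # reserve/finalize against owner_id's premium pool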
+ """
+ from sqlalchemy import select
+
+ from app.db import NewLLMConfig, SearchSpace
+
+ result = await session.execute(
+ select(SearchSpace).where(SearchSpace.id == search_space_id)
+ )
+ search_space = result.scalars().first()
+ if search_space is None:
+ raise ValueError(f"Search space {search_space_id} not found")
+
+ agent_llm_id = search_space.agent_llm_id
+ if agent_llm_id is None:
+ raise ValueError(
+ f"Search space {search_space_id} has no agent_llm_id configured"
+ )
+
+ owner_user_id: UUID = search_space.user_id
+
+ from app.services.auto_model_pin_service import (
+ AUTO_FASTEST_ID,
+ resolve_or_get_pinned_llm_config_id,
+ )
+
+ if agent_llm_id == AUTO_FASTEST_ID:
+ if thread_id is None:
+ return owner_user_id, "free", "auto"
+ try:
+ resolution = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=str(owner_user_id),
+ selected_llm_config_id=AUTO_FASTEST_ID,
+ )
+ except ValueError:
+ logger.warning(
+ "[agent_billing] Auto-mode pin resolution failed for "
+ "search_space=%s thread=%s; falling back to free",
+ search_space_id,
+ thread_id,
+ exc_info=True,
+ )
+ return owner_user_id, "free", "auto"
+ agent_llm_id = resolution.resolved_llm_config_id
+
+ if agent_llm_id < 0:
+ from app.services.llm_service import get_global_llm_config
+
+ cfg = get_global_llm_config(agent_llm_id) or {}
+ billing_tier = str(cfg.get("billing_tier", "free")).lower()
+ litellm_params = cfg.get("litellm_params") or {}
+ base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
+ return owner_user_id, billing_tier, base_model
+
+ nlc_result = await session.execute(
+ select(NewLLMConfig).where(
+ NewLLMConfig.id == agent_llm_id,
+ NewLLMConfig.search_space_id == search_space_id,
+ )
+ )
+ nlc = nlc_result.scalars().first()
+ base_model = ""
+ if nlc is not None:
+ litellm_params = nlc.litellm_params or {}
+ base_model = litellm_params.get("base_model") or nlc.model_name or ""
+ return owner_user_id, "free", base_model
+
+
+__all__ = [
+ "BillingSettlementError",
+ "DEFAULT_IMAGE_RESERVE_MICROS",
+ "QuotaInsufficientError",
+ "_resolve_agent_billing_for_search_space",
+ "billable_call",
+]
+
+
+# Re-export the config knob so callers don't have to import config just for
+# the default image reserve.
+DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS
diff --git a/surfsense_backend/app/services/composio_service.py b/surfsense_backend/app/services/composio_service.py
index a8abe4aa8..edfab1d15 100644
--- a/surfsense_backend/app/services/composio_service.py
+++ b/surfsense_backend/app/services/composio_service.py
@@ -408,12 +408,37 @@ class ComposioService:
files = []
next_token = None
if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ response_data = (
+ inner_data.get("response_data", {})
+ if isinstance(inner_data, dict)
+ else {}
+ )
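+ # Envelope shapes seen from Composio (illustrative): {"files": [...]},
+ # {"data": {"files": [...]}}, and
+ # {"data": {"response_data": {"files": [...]}}}.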
# Try direct access first, then nested
- files = data.get("files", []) or data.get("data", {}).get("files", [])
+ files = (
+ data.get("files", [])
+ or (
+ inner_data.get("files", [])
+ if isinstance(inner_data, dict)
+ else []
+ )
+ or response_data.get("files", [])
+ )
next_token = (
data.get("nextPageToken")
or data.get("next_page_token")
- or data.get("data", {}).get("nextPageToken")
+ or (
+ inner_data.get("nextPageToken")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or (
+ inner_data.get("next_page_token")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or response_data.get("nextPageToken")
+ or response_data.get("next_page_token")
)
elif isinstance(data, list):
files = data
@@ -819,24 +844,61 @@ class ComposioService:
next_token = None
result_size_estimate = None
if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ response_data = (
+ inner_data.get("response_data", {})
+ if isinstance(inner_data, dict)
+ else {}
+ )
messages = (
data.get("messages", [])
- or data.get("data", {}).get("messages", [])
+ or (
+ inner_data.get("messages", [])
+ if isinstance(inner_data, dict)
+ else []
+ )
+ or response_data.get("messages", [])
or data.get("emails", [])
+ or (
+ inner_data.get("emails", [])
+ if isinstance(inner_data, dict)
+ else []
+ )
+ or response_data.get("emails", [])
)
# Check for pagination token in various possible locations
next_token = (
data.get("nextPageToken")
or data.get("next_page_token")
- or data.get("data", {}).get("nextPageToken")
- or data.get("data", {}).get("next_page_token")
+ or (
+ inner_data.get("nextPageToken")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or (
+ inner_data.get("next_page_token")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or response_data.get("nextPageToken")
+ or response_data.get("next_page_token")
)
# Extract resultSizeEstimate if available (Gmail API provides this)
result_size_estimate = (
data.get("resultSizeEstimate")
or data.get("result_size_estimate")
- or data.get("data", {}).get("resultSizeEstimate")
- or data.get("data", {}).get("result_size_estimate")
+ or (
+ inner_data.get("resultSizeEstimate")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or (
+ inner_data.get("result_size_estimate")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or response_data.get("resultSizeEstimate")
+ or response_data.get("result_size_estimate")
)
elif isinstance(data, list):
messages = data
@@ -864,7 +926,7 @@ class ComposioService:
try:
result = await self.execute_tool(
connected_account_id=connected_account_id,
- tool_name="GMAIL_GET_MESSAGE_BY_MESSAGE_ID",
+ tool_name="GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID",
params={"message_id": message_id}, # snake_case
entity_id=entity_id,
)
@@ -872,7 +934,13 @@ class ComposioService:
if not result.get("success"):
return None, result.get("error", "Unknown error")
- return result.get("data"), None
+ data = result.get("data")
+ if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ if isinstance(inner_data, dict):
+ return inner_data.get("response_data", inner_data), None
+
+ return data, None
except Exception as e:
logger.error(f"Failed to get Gmail message detail: {e!s}")
@@ -928,10 +996,27 @@ class ComposioService:
# Try different possible response structures
events = []
if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ response_data = (
+ inner_data.get("response_data", {})
+ if isinstance(inner_data, dict)
+ else {}
+ )
events = (
data.get("items", [])
- or data.get("data", {}).get("items", [])
+ or (
+ inner_data.get("items", [])
+ if isinstance(inner_data, dict)
+ else []
+ )
+ or response_data.get("items", [])
or data.get("events", [])
+ or (
+ inner_data.get("events", [])
+ if isinstance(inner_data, dict)
+ else []
+ )
+ or response_data.get("events", [])
)
elif isinstance(data, list):
events = data
diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py
index 7c55da2e5..45bcfd00f 100644
--- a/surfsense_backend/app/services/connector_service.py
+++ b/surfsense_backend/app/services/connector_service.py
@@ -1,6 +1,8 @@
import asyncio
+import os
import time
from datetime import datetime
+from threading import Lock
from typing import Any
import httpx
@@ -2769,12 +2771,22 @@ class ConnectorService:
"""
Get all available (enabled) connector types for a search space.
+ Phase 1.4: results are cached per ``search_space_id`` for
+ :data:`_DISCOVERY_TTL_SECONDS`. Cache key is independent of session
+ identity — the cached value is plain data, safe to share across
+ requests. Invalidate on connector add/update/delete via
+ :func:`invalidate_connector_discovery_cache`.
+
Args:
search_space_id: The search space ID
Returns:
List of SearchSourceConnectorType enums for enabled connectors
"""
+ cached = _get_cached_connectors(search_space_id)
+ if cached is not None:
+ return list(cached)
+
query = (
select(SearchSourceConnector.connector_type)
.filter(
@@ -2784,8 +2796,9 @@ class ConnectorService:
)
result = await self.session.execute(query)
- connector_types = result.scalars().all()
- return list(connector_types)
+ connector_types = list(result.scalars().all())
+ _set_cached_connectors(search_space_id, connector_types)
+ return connector_types
async def get_available_document_types(
self,
@@ -2794,12 +2807,22 @@ class ConnectorService:
"""
Get all document types that have at least one document in the search space.
+ Phase 1.4: cached per ``search_space_id`` for
+ :data:`_DISCOVERY_TTL_SECONDS`. Invalidate via
+ :func:`invalidate_connector_discovery_cache` when a connector
+ finishes indexing new documents (or document types are otherwise
+ added/removed).
+
Args:
search_space_id: The search space ID
Returns:
List of document type strings that have documents indexed
"""
+ cached = _get_cached_doc_types(search_space_id)
+ if cached is not None:
+ return list(cached)
+
from sqlalchemy import distinct
from app.db import Document
@@ -2809,5 +2832,164 @@ class ConnectorService:
)
result = await self.session.execute(query)
- doc_types = result.scalars().all()
- return [str(dt) for dt in doc_types]
+ doc_types = [str(dt) for dt in result.scalars().all()]
+ _set_cached_doc_types(search_space_id, doc_types)
+ return doc_types
+
+
+# ---------------------------------------------------------------------------
+# Connector / document-type discovery TTL cache (Phase 1.4)
+# ---------------------------------------------------------------------------
+#
+# Both ``get_available_connectors`` and ``get_available_document_types`` are
+# called on EVERY chat turn from ``create_surfsense_deep_agent``. Each query
+# hits Postgres and contributes to per-turn agent build latency. Their
+# results change infrequently — only when the user adds/edits/removes a
+# connector, or when an indexer commits a new document type. A short TTL
+# cache (default 30s, env-tunable) collapses the repeated per-turn calls
+# within a TTL window into a single DB roundtrip, at the cost of bounded
+# staleness.
+#
+# Invalidation: connector mutation routes (create / update / delete) call
+# ``invalidate_connector_discovery_cache(search_space_id)`` to clear the
+# entry for the affected space. Multi-replica deployments still pay one
+# DB roundtrip per replica per TTL window, which is fine — staleness is
+# bounded and the alternative (cross-replica fanout) is not worth the
+# coupling here.
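+#
+# Tuning sketch (values illustrative, not defaults):
+#   SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=5  -> tighter staleness bound
+#   SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=0  -> caching disabled; every
+#                                                   call hits Postgres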
+
+_DISCOVERY_TTL_SECONDS: float = float(
+ os.getenv("SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS", "30")
+)
+
+# Per-search-space caches. Keyed by ``search_space_id``; value is
+# ``(expires_at_monotonic, payload)``. Plain dicts protected by a lock —
+# read-mostly workload, sub-microsecond contention.
+_connectors_cache: dict[int, tuple[float, list[SearchSourceConnectorType]]] = {}
+_doc_types_cache: dict[int, tuple[float, list[str]]] = {}
+_cache_lock = Lock()
+
+
+def _get_cached_connectors(
+ search_space_id: int,
+) -> list[SearchSourceConnectorType] | None:
+ if _DISCOVERY_TTL_SECONDS <= 0:
+ return None
+ with _cache_lock:
+ entry = _connectors_cache.get(search_space_id)
+ if entry is None:
+ return None
+ expires_at, payload = entry
+ if time.monotonic() >= expires_at:
+ _connectors_cache.pop(search_space_id, None)
+ return None
+ return payload
+
+
+def _set_cached_connectors(
+ search_space_id: int, payload: list[SearchSourceConnectorType]
+) -> None:
+ if _DISCOVERY_TTL_SECONDS <= 0:
+ return
+ expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
+ with _cache_lock:
+ _connectors_cache[search_space_id] = (expires_at, list(payload))
+
+
+def _get_cached_doc_types(search_space_id: int) -> list[str] | None:
+ if _DISCOVERY_TTL_SECONDS <= 0:
+ return None
+ with _cache_lock:
+ entry = _doc_types_cache.get(search_space_id)
+ if entry is None:
+ return None
+ expires_at, payload = entry
+ if time.monotonic() >= expires_at:
+ _doc_types_cache.pop(search_space_id, None)
+ return None
+ return payload
+
+
+def _set_cached_doc_types(search_space_id: int, payload: list[str]) -> None:
+ if _DISCOVERY_TTL_SECONDS <= 0:
+ return
+ expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
+ with _cache_lock:
+ _doc_types_cache[search_space_id] = (expires_at, list(payload))
+
+
+def invalidate_connector_discovery_cache(search_space_id: int | None = None) -> None:
+ """Drop cached discovery results for ``search_space_id`` (or all spaces).
+
+ Connector CRUD routes / indexer pipelines call this when they mutate
+ the rows backing :func:`ConnectorService.get_available_connectors` /
+ :func:`get_available_document_types`. ``None`` clears every space —
+ useful in tests and on bulk imports.
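+
+ Hedged sketch for a bulk path the ORM listeners can't see
+ (``delete`` is ``sqlalchemy.delete``; the filter is illustrative)::
+
+ await session.execute(delete(SearchSourceConnector).where(...))
+ await session.commit()
+ invalidate_connector_discovery_cache(search_space_id)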
+ """
+ with _cache_lock:
+ if search_space_id is None:
+ _connectors_cache.clear()
+ _doc_types_cache.clear()
+ else:
+ _connectors_cache.pop(search_space_id, None)
+ _doc_types_cache.pop(search_space_id, None)
+
+
+def _invalidate_connectors_only(search_space_id: int | None = None) -> None:
+ with _cache_lock:
+ if search_space_id is None:
+ _connectors_cache.clear()
+ else:
+ _connectors_cache.pop(search_space_id, None)
+
+
+def _invalidate_doc_types_only(search_space_id: int | None = None) -> None:
+ with _cache_lock:
+ if search_space_id is None:
+ _doc_types_cache.clear()
+ else:
+ _doc_types_cache.pop(search_space_id, None)
+
+
+def _register_invalidation_listeners() -> None:
+ """Wire SQLAlchemy ORM events so cache stays consistent automatically.
+
+ Listening on ``after_insert`` / ``after_update`` / ``after_delete``
+ means every successful INSERT/UPDATE/DELETE that goes through the ORM
+ invalidates the affected search space's cached discovery payload —
+ no need to sprinkle ``invalidate_*`` calls across 30+ connector
+ routes. Bulk operations that bypass the ORM (e.g.
+ ``session.execute(insert(...))`` without a mapped object) still need
+ explicit invalidation; document indexers already commit through the
+ ORM so document-type discovery is covered.
+ """
+ from sqlalchemy import event
+
+ # Imported here (not at module top) to avoid a circular import:
+ # app.services.connector_service is itself imported from app.db's
+ # ecosystem indirectly via several CRUD modules.
+ from app.db import Document, SearchSourceConnector
+
+ def _connector_changed(_mapper, _connection, target) -> None:
+ sid = getattr(target, "search_space_id", None)
+ if sid is not None:
+ _invalidate_connectors_only(int(sid))
+
+ def _document_changed(_mapper, _connection, target) -> None:
+ sid = getattr(target, "search_space_id", None)
+ if sid is not None:
+ _invalidate_doc_types_only(int(sid))
+
+ for evt in ("after_insert", "after_update", "after_delete"):
+ event.listen(SearchSourceConnector, evt, _connector_changed)
+ event.listen(Document, evt, _document_changed)
+
+
+try:
+ _register_invalidation_listeners()
+except Exception: # pragma: no cover - defensive; never block module import
+ import logging as _logging
+
+ _logging.getLogger(__name__).exception(
+ "Failed to register connector discovery cache invalidation listeners; "
+ "stale cache risk: explicit invalidate_connector_discovery_cache calls "
+ "may be required."
+ )
diff --git a/surfsense_backend/app/services/gmail/tool_metadata_service.py b/surfsense_backend/app/services/gmail/tool_metadata_service.py
index c903e24af..4855c1cc9 100644
--- a/surfsense_backend/app/services/gmail/tool_metadata_service.py
+++ b/surfsense_backend/app/services/gmail/tool_metadata_service.py
@@ -17,7 +17,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
-from app.utils.google_credentials import build_composio_credentials
+from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__)
@@ -78,14 +78,49 @@ class GmailToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
- async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
- if (
+ def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
+ return (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- return build_composio_credentials(cca_id)
+ )
+
+ def _get_composio_connected_account_id(
+ self, connector: SearchSourceConnector
+ ) -> str:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ raise ValueError("Composio connected_account_id not found")
+ return cca_id
+
+ def _unwrap_composio_data(self, data: Any) -> Any:
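+ # Peels Composio's response envelope; shapes are illustrative:
+ # {"data": {"response_data": {...}}} -> {...}
+ # {"data": {...}} -> {...}
+ # list / scalar -> unchanged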
+ if isinstance(data, dict):
+ inner = data.get("data", data)
+ if isinstance(inner, dict):
+ return inner.get("response_data", inner)
+ return inner
+ return data
+
+ async def _execute_composio_gmail_tool(
+ self,
+ connector: SearchSourceConnector,
+ tool_name: str,
+ params: dict[str, Any],
+ ) -> tuple[Any, str | None]:
+ result = await ComposioService().execute_tool(
+ connected_account_id=self._get_composio_connected_account_id(connector),
+ tool_name=tool_name,
+ params=params,
+ entity_id=f"surfsense_{connector.user_id}",
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown Composio Gmail error")
+ return self._unwrap_composio_data(result.get("data")), None
+
+ async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
+ if self._is_composio_connector(connector):
+ raise ValueError(
+ "Composio Gmail connectors must use Composio tool execution"
+ )
config_data = dict(connector.config)
@@ -139,6 +174,12 @@ class GmailToolMetadataService:
if not connector:
return True
+ if self._is_composio_connector(connector):
+ _profile, error = await self._execute_composio_gmail_tool(
+ connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
+ )
+ return bool(error)
+
creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds)
await asyncio.get_event_loop().run_in_executor(
@@ -221,14 +262,21 @@ class GmailToolMetadataService:
)
connector = result.scalar_one_or_none()
if connector:
- creds = await self._build_credentials(connector)
- service = build("gmail", "v1", credentials=creds)
- profile = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda service=service: (
- service.users().getProfile(userId="me").execute()
- ),
- )
+ if self._is_composio_connector(connector):
+ profile, error = await self._execute_composio_gmail_tool(
+ connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
+ )
+ if error:
+ raise RuntimeError(error)
+ else:
+ creds = await self._build_credentials(connector)
+ service = build("gmail", "v1", credentials=creds)
+ profile = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda service=service: (
+ service.users().getProfile(userId="me").execute()
+ ),
+ )
acc_dict["email"] = profile.get("emailAddress", "")
except Exception:
logger.warning(
@@ -298,6 +346,23 @@ class GmailToolMetadataService:
Returns ``None`` on any failure so callers can degrade gracefully.
"""
try:
+ if self._is_composio_connector(connector):
+ if not draft_id:
+ draft_id = await self._find_composio_draft_id(connector, message_id)
+ if not draft_id:
+ return None
+
+ draft, error = await self._execute_composio_gmail_tool(
+ connector,
+ "GMAIL_GET_DRAFT",
+ {"user_id": "me", "draft_id": draft_id, "format": "full"},
+ )
+ if error or not isinstance(draft, dict):
+ return None
+
+ payload = draft.get("message", {}).get("payload", {})
+ return self._extract_body_from_payload(payload)
+
creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds)
@@ -326,6 +391,33 @@ class GmailToolMetadataService:
)
return None
+ async def _find_composio_draft_id(
+ self, connector: SearchSourceConnector, message_id: str
+ ) -> str | None:
+ page_token = ""
+ while True:
+ params: dict[str, Any] = {
+ "user_id": "me",
+ "max_results": 100,
+ "verbose": False,
+ }
+ if page_token:
+ params["page_token"] = page_token
+
+ data, error = await self._execute_composio_gmail_tool(
+ connector, "GMAIL_LIST_DRAFTS", params
+ )
+ if error or not isinstance(data, dict):
+ return None
+
+ for draft in data.get("drafts", []):
+ if draft.get("message", {}).get("id") == message_id:
+ return draft.get("id")
+
+ page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
+ if not page_token:
+ return None
+
async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
"""Resolve a draft ID from its message ID by scanning drafts.list."""
try:
diff --git a/surfsense_backend/app/services/google_calendar/kb_sync_service.py b/surfsense_backend/app/services/google_calendar/kb_sync_service.py
index 20426f3bc..602a55738 100644
--- a/surfsense_backend/app/services/google_calendar/kb_sync_service.py
+++ b/surfsense_backend/app/services/google_calendar/kb_sync_service.py
@@ -14,6 +14,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
+from app.services.composio_service import ComposioService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@@ -21,7 +22,6 @@ from app.utils.document_converters import (
generate_document_summary,
generate_unique_identifier_hash,
)
-from app.utils.google_credentials import build_composio_credentials
logger = logging.getLogger(__name__)
@@ -203,23 +203,46 @@ class GoogleCalendarKBSyncService:
logger.warning("Document %s not found in KB", document_id)
return {"status": "not_indexed"}
- creds = await self._build_credentials_for_connector(connector_id)
- loop = asyncio.get_event_loop()
- service = await loop.run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
-
calendar_id = (document.document_metadata or {}).get(
"calendar_id"
) or "primary"
- live_event = await loop.run_in_executor(
- None,
- lambda: (
- service.events()
- .get(calendarId=calendar_id, eventId=event_id)
- .execute()
- ),
- )
+ connector = await self._get_connector(connector_id)
+ if (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+ ):
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ raise ValueError("Composio connected_account_id not found")
+ composio_result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLECALENDAR_EVENTS_GET",
+ params={"calendar_id": calendar_id, "event_id": event_id},
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not composio_result.get("success"):
+ raise RuntimeError(
+ composio_result.get("error", "Unknown Composio Calendar error")
+ )
+ live_event = composio_result.get("data", {})
+ if isinstance(live_event, dict):
+ live_event = live_event.get("data", live_event)
+ if isinstance(live_event, dict):
+ live_event = live_event.get("response_data", live_event)
+ else:
+ creds = await self._build_credentials_for_connector(connector_id)
+ loop = asyncio.get_event_loop()
+ service = await loop.run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ live_event = await loop.run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .get(calendarId=calendar_id, eventId=event_id)
+ .execute()
+ ),
+ )
event_summary = live_event.get("summary", "")
description = live_event.get("description", "")
@@ -322,7 +345,7 @@ class GoogleCalendarKBSyncService:
await self.db_session.rollback()
return {"status": "error", "message": str(e)}
- async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
+ async def _get_connector(self, connector_id: int) -> SearchSourceConnector:
result = await self.db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
@@ -331,15 +354,17 @@ class GoogleCalendarKBSyncService:
connector = result.scalar_one_or_none()
if not connector:
raise ValueError(f"Connector {connector_id} not found")
+ return connector
+ async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
+ connector = await self._get_connector(connector_id)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- return build_composio_credentials(cca_id)
- raise ValueError("Composio connected_account_id not found")
+ raise ValueError(
+ "Composio Calendar connectors must use Composio tool execution"
+ )
config_data = dict(connector.config)
diff --git a/surfsense_backend/app/services/google_calendar/tool_metadata_service.py b/surfsense_backend/app/services/google_calendar/tool_metadata_service.py
index c7bfe1d50..7e50ab039 100644
--- a/surfsense_backend/app/services/google_calendar/tool_metadata_service.py
+++ b/surfsense_backend/app/services/google_calendar/tool_metadata_service.py
@@ -16,7 +16,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
-from app.utils.google_credentials import build_composio_credentials
+from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__)
@@ -94,15 +94,49 @@ class GoogleCalendarToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
- async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
- if (
+ def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
+ return (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- ):
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- return build_composio_credentials(cca_id)
+ )
+
+ def _get_composio_connected_account_id(
+ self, connector: SearchSourceConnector
+ ) -> str:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
raise ValueError("Composio connected_account_id not found")
+ return cca_id
+
+ async def _execute_composio_calendar_tool(
+ self,
+ connector: SearchSourceConnector,
+ tool_name: str,
+ params: dict,
+ ) -> tuple[dict | list | None, str | None]:
+ service = ComposioService()
+ result = await service.execute_tool(
+ connected_account_id=self._get_composio_connected_account_id(connector),
+ tool_name=tool_name,
+ params=params,
+ entity_id=f"surfsense_{connector.user_id}",
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown Composio Calendar error")
+
+ data = result.get("data")
+ if isinstance(data, dict):
+ inner = data.get("data", data)
+ if isinstance(inner, dict):
+ return inner.get("response_data", inner), None
+ return inner, None
+ return data, None
+
+ async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
+ if self._is_composio_connector(connector):
+ raise ValueError(
+ "Composio Calendar connectors must use Composio tool execution"
+ )
config_data = dict(connector.config)
@@ -156,6 +190,14 @@ class GoogleCalendarToolMetadataService:
if not connector:
return True
+ if self._is_composio_connector(connector):
+ _data, error = await self._execute_composio_calendar_tool(
+ connector,
+ "GOOGLECALENDAR_GET_CALENDAR",
+ {"calendar_id": "primary"},
+ )
+ return bool(error)
+
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
await loop.run_in_executor(
@@ -255,16 +297,48 @@ class GoogleCalendarToolMetadataService:
timezone_str = ""
if connector:
try:
- creds = await self._build_credentials(connector)
- loop = asyncio.get_event_loop()
- service = await loop.run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
+ if self._is_composio_connector(connector):
+ cal_list, cal_error = await self._execute_composio_calendar_tool(
+ connector, "GOOGLECALENDAR_LIST_CALENDARS", {}
+ )
+ if cal_error:
+ raise RuntimeError(cal_error)
+ (
+ settings,
+ settings_error,
+ ) = await self._execute_composio_calendar_tool(
+ connector,
+ "GOOGLECALENDAR_SETTINGS_GET",
+ {"setting": "timezone"},
+ )
+ if not settings_error and isinstance(settings, dict):
+ timezone_str = settings.get("value", "")
+ else:
+ creds = await self._build_credentials(connector)
+ loop = asyncio.get_event_loop()
+ service = await loop.run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
- cal_list = await loop.run_in_executor(
- None, lambda: service.calendarList().list().execute()
- )
- for cal in cal_list.get("items", []):
+ cal_list = await loop.run_in_executor(
+ None, lambda: service.calendarList().list().execute()
+ )
+
+ tz_setting = await loop.run_in_executor(
+ None,
+ lambda: service.settings().get(setting="timezone").execute(),
+ )
+ timezone_str = tz_setting.get("value", "")
+
+ calendar_items = []
+ if isinstance(cal_list, dict):
+ calendar_items = (
+ cal_list.get("items") or cal_list.get("calendars") or []
+ )
+ elif isinstance(cal_list, list):
+ calendar_items = cal_list
+
+ for cal in calendar_items:
calendars.append(
{
"id": cal.get("id", ""),
@@ -272,12 +346,6 @@ class GoogleCalendarToolMetadataService:
"primary": cal.get("primary", False),
}
)
-
- tz_setting = await loop.run_in_executor(
- None,
- lambda: service.settings().get(setting="timezone").execute(),
- )
- timezone_str = tz_setting.get("value", "")
except Exception:
logger.warning(
"Failed to fetch calendars/timezone for connector %s",
@@ -321,20 +389,29 @@ class GoogleCalendarToolMetadataService:
event_dict = event.to_dict()
try:
- creds = await self._build_credentials(connector)
- loop = asyncio.get_event_loop()
- service = await loop.run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
calendar_id = event.calendar_id or "primary"
- live_event = await loop.run_in_executor(
- None,
- lambda: (
- service.events()
- .get(calendarId=calendar_id, eventId=event.event_id)
- .execute()
- ),
- )
+ if self._is_composio_connector(connector):
+ live_event, error = await self._execute_composio_calendar_tool(
+ connector,
+ "GOOGLECALENDAR_EVENTS_GET",
+ {"calendar_id": calendar_id, "event_id": event.event_id},
+ )
+ if error:
+ raise RuntimeError(error)
+ else:
+ creds = await self._build_credentials(connector)
+ loop = asyncio.get_event_loop()
+ service = await loop.run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ live_event = await loop.run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .get(calendarId=calendar_id, eventId=event.event_id)
+ .execute()
+ ),
+ )
event_dict["summary"] = live_event.get("summary", event_dict["summary"])
event_dict["description"] = live_event.get(
@@ -376,12 +453,30 @@ class GoogleCalendarToolMetadataService:
) -> dict:
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
if not resolved:
+ live_resolved = await self._resolve_live_event(
+ search_space_id, user_id, event_ref
+ )
+ if not live_resolved:
+ return {
+ "error": (
+ f"Event '{event_ref}' not found in your indexed or live Google Calendar events. "
+ "This could mean: (1) the event doesn't exist, "
+ "(2) the event name is different, or "
+ "(3) the connected calendar account cannot access it."
+ )
+ }
+
+ connector, live_event = live_resolved
+ account = GoogleCalendarAccount.from_connector(connector)
+ acc_dict = account.to_dict()
+ auth_expired = await self._check_account_health(connector.id)
+ acc_dict["auth_expired"] = auth_expired
+ if auth_expired:
+ await self._persist_auth_expired(connector.id)
+
return {
- "error": (
- f"Event '{event_ref}' not found in your indexed Google Calendar events. "
- "This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
- "or (3) the event name is different."
- )
+ "account": acc_dict,
+ "event": self._event_dict_from_live_event(live_event),
}
document, connector = resolved
@@ -429,3 +524,110 @@ class GoogleCalendarToolMetadataService:
if row:
return row[0], row[1]
return None
+
+ async def _resolve_live_event(
+ self, search_space_id: int, user_id: str, event_ref: str
+ ) -> tuple[SearchSourceConnector, dict] | None:
+ result = await self._db_session.execute(
+ select(SearchSourceConnector)
+ .filter(
+ and_(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES),
+ )
+ )
+ .order_by(SearchSourceConnector.last_indexed_at.desc())
+ )
+ connectors = result.scalars().all()
+
+ for connector in connectors:
+ try:
+ events = await self._search_live_events(connector, event_ref)
+ except Exception:
+ logger.warning(
+ "Failed to search live calendar events for connector %s",
+ connector.id,
+ exc_info=True,
+ )
+ continue
+
+ if not events:
+ continue
+
+ normalized_ref = event_ref.strip().lower()
+ exact_match = next(
+ (
+ event
+ for event in events
+ if event.get("summary", "").strip().lower() == normalized_ref
+ ),
+ None,
+ )
+ return connector, exact_match or events[0]
+
+ return None
+
+ async def _search_live_events(
+ self, connector: SearchSourceConnector, event_ref: str
+ ) -> list[dict]:
+ if self._is_composio_connector(connector):
+ data, error = await self._execute_composio_calendar_tool(
+ connector,
+ "GOOGLECALENDAR_EVENTS_LIST",
+ {
+ "calendar_id": "primary",
+ "q": event_ref,
+ "max_results": 10,
+ "single_events": True,
+ "order_by": "startTime",
+ },
+ )
+ if error:
+ raise RuntimeError(error)
+ if isinstance(data, dict):
+ return data.get("items") or data.get("events") or []
+ return data if isinstance(data, list) else []
+
+ creds = await self._build_credentials(connector)
+ loop = asyncio.get_event_loop()
+ service = await loop.run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ response = await loop.run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .list(
+ calendarId="primary",
+ q=event_ref,
+ maxResults=10,
+ singleEvents=True,
+ orderBy="startTime",
+ )
+ .execute()
+ ),
+ )
+ return response.get("items", [])
+
+ def _event_dict_from_live_event(self, event: dict) -> dict:
+ start_data = event.get("start", {})
+ end_data = event.get("end", {})
+ return {
+ "event_id": event.get("id", ""),
+ "summary": event.get("summary", "No Title"),
+ "start": start_data.get("dateTime", start_data.get("date", "")),
+ "end": end_data.get("dateTime", end_data.get("date", "")),
+ "description": event.get("description", ""),
+ "location": event.get("location", ""),
+ "attendees": [
+ {
+ "email": attendee.get("email", ""),
+ "responseStatus": attendee.get("responseStatus", ""),
+ }
+ for attendee in event.get("attendees", [])
+ ],
+ "calendar_id": event.get("calendarId", "primary"),
+ "document_id": None,
+ "indexed_at": None,
+ }
diff --git a/surfsense_backend/app/services/google_drive/tool_metadata_service.py b/surfsense_backend/app/services/google_drive/tool_metadata_service.py
index 221bee14a..0f654bc78 100644
--- a/surfsense_backend/app/services/google_drive/tool_metadata_service.py
+++ b/surfsense_backend/app/services/google_drive/tool_metadata_service.py
@@ -13,7 +13,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
-from app.utils.google_credentials import build_composio_credentials
+from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__)
@@ -67,6 +67,42 @@ class GoogleDriveToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
+ def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
+ return (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
+ )
+
+ def _get_composio_connected_account_id(
+ self, connector: SearchSourceConnector
+ ) -> str:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ raise ValueError("Composio connected_account_id not found")
+ return cca_id
+
+ async def _execute_composio_drive_tool(
+ self,
+ connector: SearchSourceConnector,
+ tool_name: str,
+ params: dict,
+ ) -> tuple[dict | list | None, str | None]:
+ result = await ComposioService().execute_tool(
+ connected_account_id=self._get_composio_connected_account_id(connector),
+ tool_name=tool_name,
+ params=params,
+ entity_id=f"surfsense_{connector.user_id}",
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown Composio Drive error")
+ data = result.get("data")
+ if isinstance(data, dict):
+ inner = data.get("data", data)
+ if isinstance(inner, dict):
+ return inner.get("response_data", inner), None
+ return inner, None
+ return data, None
+
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
@@ -200,19 +236,21 @@ class GoogleDriveToolMetadataService:
if not connector:
return True
- pre_built_creds = None
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
- ):
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- pre_built_creds = build_composio_credentials(cca_id)
+ if self._is_composio_connector(connector):
+ _data, error = await self._execute_composio_drive_tool(
+ connector,
+ "GOOGLEDRIVE_LIST_FILES",
+ {
+ "q": "trashed = false",
+ "page_size": 1,
+ "fields": "files(id)",
+ },
+ )
+ return bool(error)
client = GoogleDriveClient(
session=self._db_session,
connector_id=connector_id,
- credentials=pre_built_creds,
)
await client.list_files(
query="trashed = false", page_size=1, fields="files(id)"
@@ -274,19 +312,39 @@ class GoogleDriveToolMetadataService:
parent_folders[connector_id] = []
continue
- pre_built_creds = None
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
- ):
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- pre_built_creds = build_composio_credentials(cca_id)
+ if self._is_composio_connector(connector):
+ data, error = await self._execute_composio_drive_tool(
+ connector,
+ "GOOGLEDRIVE_LIST_FILES",
+ {
+ "q": "mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents",
+ "fields": "files(id,name)",
+ "page_size": 50,
+ },
+ )
+ if error:
+ logger.warning(
+ "Failed to list folders for connector %s: %s",
+ connector_id,
+ error,
+ )
+ parent_folders[connector_id] = []
+ continue
+ folders = []
+ if isinstance(data, dict):
+ folders = data.get("files", [])
+ elif isinstance(data, list):
+ folders = data
+ parent_folders[connector_id] = [
+ {"folder_id": f["id"], "name": f["name"]}
+ for f in folders
+ if f.get("id") and f.get("name")
+ ]
+ continue
client = GoogleDriveClient(
session=self._db_session,
connector_id=connector_id,
- credentials=pre_built_creds,
)
folders, _, error = await client.list_files(
diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py
index f45a6ab63..b4de2a0bf 100644
--- a/surfsense_backend/app/services/image_gen_router_service.py
+++ b/surfsense_backend/app/services/image_gen_router_service.py
@@ -20,6 +20,8 @@ from typing import Any
from litellm import Router
from litellm.utils import ImageResponse
+from app.services.provider_api_base import resolve_api_base
+
logger = logging.getLogger(__name__)
# Special ID for Auto mode - uses router for load balancing
@@ -152,12 +154,12 @@ class ImageGenRouterService:
return None
# Build model string
+ provider = config.get("provider", "").upper()
if config.get("custom_provider"):
- model_string = f"{config['custom_provider']}/{config['model_name']}"
+ provider_prefix = config["custom_provider"]
else:
- provider = config.get("provider", "").upper()
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
- model_string = f"{provider_prefix}/{config['model_name']}"
+ model_string = f"{provider_prefix}/{config['model_name']}"
# Build litellm params
litellm_params: dict[str, Any] = {
@@ -165,9 +167,16 @@ class ImageGenRouterService:
"api_key": config.get("api_key"),
}
- # Add optional api_base
- if config.get("api_base"):
- litellm_params["api_base"] = config["api_base"]
+ # Resolve ``api_base`` so deployments don't silently inherit
+ # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
+ # the wrong provider (see ``provider_api_base`` docstring).
+ api_base = resolve_api_base(
+ provider=provider,
+ provider_prefix=provider_prefix,
+ config_api_base=config.get("api_base"),
+ )
+ if api_base:
+ litellm_params["api_base"] = api_base
# Add api_version (required for Azure)
if config.get("api_version"):
diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py
index 4bce79a43..d220aa346 100644
--- a/surfsense_backend/app/services/llm_router_service.py
+++ b/surfsense_backend/app/services/llm_router_service.py
@@ -28,6 +28,7 @@ from litellm.exceptions import (
BadRequestError as LiteLLMBadRequestError,
ContextWindowExceededError,
)
+from pydantic import Field
from app.utils.perf import get_perf_logger
@@ -133,42 +134,14 @@ PROVIDER_MAP = {
}
-# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
-# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
-# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
-# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
-# request to an Azure endpoint, which then 404s with ``Resource not found``.
-# Only providers with a well-known, stable public base URL are listed here —
-# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
-# huggingface, databricks, cloudflare, replicate) are intentionally omitted
-# so their existing config-driven behaviour is preserved.
-PROVIDER_DEFAULT_API_BASE = {
- "openrouter": "https://openrouter.ai/api/v1",
- "groq": "https://api.groq.com/openai/v1",
- "mistral": "https://api.mistral.ai/v1",
- "perplexity": "https://api.perplexity.ai",
- "xai": "https://api.x.ai/v1",
- "cerebras": "https://api.cerebras.ai/v1",
- "deepinfra": "https://api.deepinfra.com/v1/openai",
- "fireworks_ai": "https://api.fireworks.ai/inference/v1",
- "together_ai": "https://api.together.xyz/v1",
- "anyscale": "https://api.endpoints.anyscale.com/v1",
- "cometapi": "https://api.cometapi.com/v1",
- "sambanova": "https://api.sambanova.ai/v1",
-}
-
-
-# Canonical provider → base URL when a config uses a generic ``openai``-style
-# prefix but the ``provider`` field tells us which API it really is
-# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
-# each has its own base URL).
-PROVIDER_KEY_DEFAULT_API_BASE = {
- "DEEPSEEK": "https://api.deepseek.com/v1",
- "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
- "MOONSHOT": "https://api.moonshot.ai/v1",
- "ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
- "MINIMAX": "https://api.minimax.io/v1",
-}
+# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
+# hoisted to ``app.services.provider_api_base`` so vision and image-gen
+# call sites can share the exact same defense (OpenRouter / Groq / etc.
+# 404-ing against an inherited Azure endpoint). Re-exported here for
+# backward compatibility with any external import.
+from app.services.provider_api_base import ( # noqa: E402,F401
+ PROVIDER_DEFAULT_API_BASE,
+ PROVIDER_KEY_DEFAULT_API_BASE,
+ resolve_api_base,
+)
class LLMRouterService:
@@ -207,6 +180,12 @@ class LLMRouterService:
"""
Initialize the router with global LLM configurations.
+ Configs with ``router_pool_eligible=False`` are skipped so that
+ dynamic OpenRouter entries stay out of the shared router pool used
+ by title-gen / sub-agent ``model="auto"`` flows. Those dynamic
+ entries are still available for user-facing Auto-mode thread pinning
+ via ``auto_model_pin_service``.
+
Args:
global_configs: List of global LLM config dictionaries from YAML
router_settings: Optional router settings (routing_strategy, num_retries, etc.)
@@ -220,6 +199,8 @@ class LLMRouterService:
model_list = []
premium_models: set[str] = set()
for config in global_configs:
+ if config.get("router_pool_eligible") is False:
+ continue
deployment = cls._config_to_deployment(config)
if deployment:
model_list.append(deployment)
@@ -308,10 +289,45 @@ class LLMRouterService:
logger.error(f"Failed to initialize LLM Router: {e}")
instance._router = None
+ @classmethod
+ def rebuild(
+ cls,
+ global_configs: list[dict],
+ router_settings: dict | None = None,
+ ) -> None:
+ """Reset the router and re-run ``initialize`` with fresh configs.
+
+ ``initialize`` short-circuits once it has run to avoid re-creating the
+ LiteLLM Router on every request; ``rebuild`` deliberately clears
+ ``_initialized`` so a caller (e.g. background OpenRouter refresh)
+ can force the pool to be rebuilt after catalogue changes.
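+
+ Hedged sketch (``load_global_llm_configs`` is an illustrative loader,
+ not a real helper)::
+
+ fresh_configs = load_global_llm_configs() # YAML + dynamic entries
+ LLMRouterService.rebuild(fresh_configs, router_settings=None)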
+ """
+ instance = cls.get_instance()
+ instance._initialized = False
+ instance._router = None
+ instance._model_list = []
+ instance._premium_model_strings = set()
+ cls.initialize(global_configs, router_settings)
+
@classmethod
def is_premium_model(cls, model_string: str) -> bool:
- """Return True if *model_string* (as reported by LiteLLM) belongs to a
- premium-tier deployment in the router pool."""
+ """Return True if *model_string* belongs to a premium-tier deployment
+ in the LiteLLM router pool.
+
+ Scope: only covers configs not explicitly marked
+ ``router_pool_eligible=False`` (the pool-skip test is ``is False``). That
+ includes static YAML premium configs AND dynamic OpenRouter *premium*
+ entries (which opt in at generation time). Dynamic OpenRouter *free*
+ entries are deliberately kept out of the router pool — OpenRouter
+ enforces free-tier limits globally per account, so per-deployment
+ router accounting can't represent them correctly — and therefore
+ return ``False`` here, which matches their ``billing_tier="free"``
+ (no premium quota).
+
+ For per-request premium checks on an arbitrary config (static or
+ dynamic, pool or non-pool), read ``agent_config.is_premium`` instead;
+ that reflects the per-config ``billing_tier`` directly and is what
+ user-facing Auto-mode thread pinning uses to bill correctly.
+ """
instance = cls.get_instance()
return model_string in instance._premium_model_strings
@@ -422,14 +438,14 @@ class LLMRouterService:
# Resolve ``api_base``. Config value wins; otherwise apply a
# provider-aware default so the deployment does not silently
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
- # requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
+ # requests to the wrong endpoint. See ``provider_api_base``
# docstring for the motivating bug (OpenRouter models 404-ing
# against an Azure endpoint).
- api_base = config.get("api_base")
- if not api_base:
- api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
- if not api_base:
- api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
+ api_base = resolve_api_base(
+ provider=provider,
+ provider_prefix=provider_prefix,
+ config_api_base=config.get("api_base"),
+ )
if api_base:
litellm_params["api_base"] = api_base
@@ -573,6 +589,11 @@ class ChatLiteLLMRouter(BaseChatModel):
# Public attributes that Pydantic will manage
model: str = "auto"
streaming: bool = True
+ # Static kwargs that flow through to ``litellm.completion(...)`` on every
+ # invocation (e.g. ``cache_control_injection_points`` set by
+ # ``apply_litellm_prompt_caching``). Per-call ``**kwargs`` from
+ # ``invoke()`` still take precedence — see ``_generate``/``_astream``.
+ model_kwargs: dict[str, Any] = Field(default_factory=dict)
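+ # Merge semantics sketch (values illustrative):
+ # llm = ChatLiteLLMRouter(router=r, model_kwargs={"temperature": 0.2})
+ # llm.invoke(msgs, temperature=0.9) -> completion(..., temperature=0.9)
+ # llm.invoke(msgs) -> completion(..., temperature=0.2)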
# Bound tools and tool choice for tool calling
_bound_tools: list[dict] | None = None
@@ -898,13 +919,16 @@ class ChatLiteLLMRouter(BaseChatModel):
logger.warning(f"Failed to convert tool {tool}: {e}")
continue
- # Create a new instance with tools bound
+ # Create a new instance with tools bound. Carry through ``model_kwargs``
+ # so static settings (e.g. cache_control_injection_points) survive the
+ # bind_tools rebuild.
return ChatLiteLLMRouter(
router=self._router,
bound_tools=formatted_tools if formatted_tools else None,
tool_choice=tool_choice,
model=self.model,
streaming=self.streaming,
+ model_kwargs=dict(self.model_kwargs),
**kwargs,
)
@@ -929,8 +953,10 @@ class ChatLiteLLMRouter(BaseChatModel):
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
- # Add tools if bound
- call_kwargs = {**kwargs}
+ # Merge static model_kwargs (e.g. cache_control_injection_points) under
+ # per-call kwargs so callers can still override per invocation. Then add
+ # bound tools.
+ call_kwargs = {**self.model_kwargs, **kwargs}
if self._bound_tools:
call_kwargs["tools"] = self._bound_tools
if self._tool_choice is not None:
@@ -997,8 +1023,10 @@ class ChatLiteLLMRouter(BaseChatModel):
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
- # Add tools if bound
- call_kwargs = {**kwargs}
+ # Merge static model_kwargs (e.g. cache_control_injection_points) under
+ # per-call kwargs so callers can still override per invocation. Then add
+ # bound tools.
+ call_kwargs = {**self.model_kwargs, **kwargs}
if self._bound_tools:
call_kwargs["tools"] = self._bound_tools
if self._tool_choice is not None:
@@ -1060,8 +1088,10 @@ class ChatLiteLLMRouter(BaseChatModel):
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
- # Add tools if bound
- call_kwargs = {**kwargs}
+ # Merge static model_kwargs (e.g. cache_control_injection_points) under
+ # per-call kwargs so callers can still override per invocation. Then add
+ # bound tools.
+ call_kwargs = {**self.model_kwargs, **kwargs}
if self._bound_tools:
call_kwargs["tools"] = self._bound_tools
if self._tool_choice is not None:
@@ -1110,8 +1140,10 @@ class ChatLiteLLMRouter(BaseChatModel):
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
- # Add tools if bound
- call_kwargs = {**kwargs}
+ # Merge static model_kwargs (e.g. cache_control_injection_points) under
+ # per-call kwargs so callers can still override per invocation. Then add
+ # bound tools.
+ call_kwargs = {**self.model_kwargs, **kwargs}
if self._bound_tools:
call_kwargs["tools"] = self._bound_tools
if self._tool_choice is not None:
diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py
index 942a9b7af..ade202c72 100644
--- a/surfsense_backend/app/services/llm_service.py
+++ b/surfsense_backend/app/services/llm_service.py
@@ -16,6 +16,7 @@ from app.services.llm_router_service import (
get_auto_mode_llm,
is_auto_mode,
)
+from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import token_tracker
# Configure litellm to automatically drop unsupported parameters
@@ -496,8 +497,14 @@ async def get_vision_llm(
- Auto mode (ID 0): VisionLLMRouterService
- Global (negative ID): YAML configs
- DB (positive ID): VisionLLMConfig table
+
+ Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
+ so each ``ainvoke`` debits the search-space owner's premium credit
+ pool. User-owned BYOK configs and free global configs are returned
+ unwrapped — they don't consume premium credit (issue M).
"""
from app.db import VisionLLMConfig
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import (
VISION_PROVIDER_MAP,
VisionLLMRouterService,
@@ -519,6 +526,8 @@ async def get_vision_llm(
logger.error(f"No vision LLM configured for search space {search_space_id}")
return None
+ owner_user_id = search_space.user_id
+
if is_vision_auto_mode(config_id):
if not VisionLLMRouterService.is_initialized():
logger.error(
@@ -526,6 +535,13 @@ async def get_vision_llm(
)
return None
try:
+ # Auto mode is currently treated as free at the wrapper
+ # level — the underlying router can dispatch to either
+ # premium or free YAML configs but routing decisions are
+ # opaque. If/when we want to bill Auto-routed vision
+ # calls we'd need to thread the resolved deployment's
+ # billing_tier back from the router. For now we keep
+ # parity with chat Auto, which also doesn't pre-classify.
return ChatLiteLLMRouter(
router=VisionLLMRouterService.get_router(),
streaming=True,
@@ -541,29 +557,46 @@ async def get_vision_llm(
return None
if global_cfg.get("custom_provider"):
- model_string = (
- f"{global_cfg['custom_provider']}/{global_cfg['model_name']}"
- )
+ provider_prefix = global_cfg["custom_provider"]
+ model_string = f"{provider_prefix}/{global_cfg['model_name']}"
else:
- prefix = VISION_PROVIDER_MAP.get(
+ provider_prefix = VISION_PROVIDER_MAP.get(
global_cfg["provider"].upper(),
global_cfg["provider"].lower(),
)
- model_string = f"{prefix}/{global_cfg['model_name']}"
+ model_string = f"{provider_prefix}/{global_cfg['model_name']}"
litellm_kwargs = {
"model": model_string,
"api_key": global_cfg["api_key"],
}
- if global_cfg.get("api_base"):
- litellm_kwargs["api_base"] = global_cfg["api_base"]
+ api_base = resolve_api_base(
+ provider=global_cfg.get("provider"),
+ provider_prefix=provider_prefix,
+ config_api_base=global_cfg.get("api_base"),
+ )
+ if api_base:
+ litellm_kwargs["api_base"] = api_base
if global_cfg.get("litellm_params"):
litellm_kwargs.update(global_cfg["litellm_params"])
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
- return SanitizedChatLiteLLM(**litellm_kwargs)
+ inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
+ billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
+ if billing_tier == "premium":
+ return QuotaCheckedVisionLLM(
+ inner_llm,
+ user_id=owner_user_id,
+ search_space_id=search_space_id,
+ billing_tier=billing_tier,
+ base_model=model_string,
+ quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
+ )
+ return inner_llm
+
+ # User-owned (positive ID) BYOK configs — always free.
result = await session.execute(
select(VisionLLMConfig).where(
VisionLLMConfig.id == config_id,
@@ -578,20 +611,26 @@ async def get_vision_llm(
return None
if vision_cfg.custom_provider:
- model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}"
+ provider_prefix = vision_cfg.custom_provider
+ model_string = f"{provider_prefix}/{vision_cfg.model_name}"
else:
- prefix = VISION_PROVIDER_MAP.get(
+ provider_prefix = VISION_PROVIDER_MAP.get(
vision_cfg.provider.value.upper(),
vision_cfg.provider.value.lower(),
)
- model_string = f"{prefix}/{vision_cfg.model_name}"
+ model_string = f"{provider_prefix}/{vision_cfg.model_name}"
litellm_kwargs = {
"model": model_string,
"api_key": vision_cfg.api_key,
}
- if vision_cfg.api_base:
- litellm_kwargs["api_base"] = vision_cfg.api_base
+ api_base = resolve_api_base(
+ provider=vision_cfg.provider.value,
+ provider_prefix=provider_prefix,
+ config_api_base=vision_cfg.api_base,
+ )
+ if api_base:
+ litellm_kwargs["api_base"] = api_base
if vision_cfg.litellm_params:
litellm_kwargs.update(vision_cfg.litellm_params)
diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py
index 3531d37af..55129668c 100644
--- a/surfsense_backend/app/services/new_streaming_service.py
+++ b/surfsense_backend/app/services/new_streaming_service.py
@@ -565,20 +565,31 @@ class VercelStreamingService:
# Error Part
# =========================================================================
- def format_error(self, error_text: str) -> str:
+ def format_error(
+ self,
+ error_text: str,
+ error_code: str | None = None,
+ extra: dict[str, object] | None = None,
+ ) -> str:
"""
Format an error message.
Args:
error_text: The error message text
+ error_code: Optional machine-readable error code for frontend branching
+ extra: Optional extra key/value pairs merged verbatim into the error payload
Returns:
str: SSE formatted error part
Example output:
- data: {"type":"error","errorText":"Something went wrong"}
+ data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
"""
- return self._format_sse({"type": "error", "errorText": error_text})
+ payload: dict[str, object] = {"type": "error", "errorText": error_text}
+ if error_code:
+ payload["errorCode"] = error_code
+ if extra:
+ payload.update(extra)
+ return self._format_sse(payload)
# =========================================================================
# Tool Parts
diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py
index 1245f73aa..6454e2d58 100644
--- a/surfsense_backend/app/services/openrouter_integration_service.py
+++ b/surfsense_backend/app/services/openrouter_integration_service.py
@@ -11,20 +11,81 @@ this service only manages the catalogue, not the inference path.
"""
import asyncio
+import hashlib
import logging
import threading
+import time
from typing import Any
import httpx
+from app.services.quality_score import (
+ _HEALTH_BLEND_WEIGHT,
+ _HEALTH_ENRICH_CONCURRENCY,
+ _HEALTH_ENRICH_TOP_N_FREE,
+ _HEALTH_ENRICH_TOP_N_PREMIUM,
+ _HEALTH_FAIL_RATIO_FALLBACK,
+ _HEALTH_FETCH_TIMEOUT_SEC,
+ aggregate_health,
+ static_score_or,
+)
+
logger = logging.getLogger(__name__)
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
+OPENROUTER_ENDPOINTS_URL_TEMPLATE = (
+ "https://openrouter.ai/api/v1/models/{model_id}/endpoints"
+)
# Sentinel value stored on each generated config so we can distinguish
# dynamic OpenRouter entries from hand-written YAML entries during refresh.
_OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__"
+# Width of the hash space used by ``_stable_config_id``. 9_000_000 provides
+# enough headroom to avoid frequent collisions for OpenRouter's catalogue
+# (~300 models) while keeping IDs comfortably within Postgres INTEGER range.
+_STABLE_ID_HASH_WIDTH = 9_000_000
+
+
+def _stable_config_id(model_id: str, offset: int, taken: set[int]) -> int:
+ """Derive a deterministic negative config ID from ``model_id``.
+
+ The same ``model_id`` always hashes to the same base value so thread pins
+ survive catalogue churn (models appearing/disappearing/reordering between
+ refreshes). On collision we decrement until we find an unused slot; this
+ keeps the mapping stable for the first config that claimed a slot and
+ only shifts collisions, which is much less disruptive than the legacy
+ index-based scheme that reshuffled every ID when the catalogue changed.
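+
+ Illustrative doctest (hypothetical model id; only the invariants are
+ asserted, never a concrete hash value):
+
+ >>> taken = set()
+ >>> first = _stable_config_id("openai/gpt-4o", -10_000, taken)
+ >>> _stable_config_id("openai/gpt-4o", -10_000, set()) == first
+ True
+ >>> _stable_config_id("openai/gpt-4o", -10_000, taken) == first - 1
+ True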
+ """
+ digest = hashlib.blake2b(model_id.encode("utf-8"), digest_size=6).digest()
+ base = offset - (int.from_bytes(digest, "big") % _STABLE_ID_HASH_WIDTH)
+ cid = base
+ while cid in taken:
+ cid -= 1
+ taken.add(cid)
+ return cid
+
+
+def _openrouter_tier(model: dict) -> str:
+ """Classify an OpenRouter model as ``"free"`` or ``"premium"``.
+
+ Per OpenRouter's API contract, a model is free if:
+ - Its id ends with ``:free`` (OpenRouter's own free-variant convention), or
+ - Both ``pricing.prompt`` and ``pricing.completion`` are zero strings.
+
+ Anything else (missing pricing, non-zero pricing) falls through to
+ ``"premium"`` so we never under-charge users. This derivation runs off the
+ already-cached /api/v1/models payload, so it adds no network cost.
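+
+ Illustrative doctest (hypothetical catalogue entries):
+
+ >>> _openrouter_tier({"id": "meta-llama/llama-3-8b:free"})
+ 'free'
+ >>> _openrouter_tier({"id": "a/b", "pricing": {"prompt": "0", "completion": "0"}})
+ 'free'
+ >>> _openrouter_tier({"id": "a/b", "pricing": {"prompt": "0.000003", "completion": "0.000015"}})
+ 'premium'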
+ """
+ if model.get("id", "").endswith(":free"):
+ return "free"
+ pricing = model.get("pricing") or {}
+ prompt = str(pricing.get("prompt", "")).strip()
+ completion = str(pricing.get("completion", "")).strip()
+ if prompt == "0" and completion == "0":
+ return "free"
+ return "premium"
+
def _is_text_output_model(model: dict) -> bool:
"""Return True if the model produces text output only (skip image/audio generators)."""
@@ -32,6 +93,53 @@ def _is_text_output_model(model: dict) -> bool:
return output_mods == ["text"]
+def _is_image_output_model(model: dict) -> bool:
+ """Return True if the model can produce image output.
+
+ OpenRouter's ``architecture.output_modalities`` is a list (e.g.
+ ``["image"]`` for pure image generators, ``["text", "image"]`` for
+ multi-modal generators that also emit captions). We accept any model
+ that can output images; the call site decides whether to use the
+ image-generation API or chat completion.
+ """
+ output_mods = model.get("architecture", {}).get("output_modalities", []) or []
+ return "image" in output_mods
+
+
+def _is_vision_input_model(model: dict) -> bool:
+ """Return True if the model can ingest an image AND emit text.
+
+ OpenRouter's ``architecture.input_modalities`` lists what the model
+ accepts; ``output_modalities`` lists what it produces. A vision LLM
+ is a model that takes images in and produces text out — i.e. it can
+ answer questions about a screenshot or extract content from an
+ image. Pure image-to-image models (e.g. style transfer) and
+ text-only models are excluded.
+ """
+ arch = model.get("architecture", {}) or {}
+ input_mods = arch.get("input_modalities", []) or []
+ output_mods = arch.get("output_modalities", []) or []
+ return "image" in input_mods and "text" in output_mods
+
+
+def _supports_image_input(model: dict) -> bool:
+ """Return True if the model accepts ``image`` in its input modalities.
+
+ Differs from :func:`_is_vision_input_model` in that it does NOT
+ require text output — chat-tab models always emit text already (the
+ chat catalog filters by ``_is_text_output_model``), so the only
+ extra capability we need to track per chat config is whether the
+ model can ingest user-attached images. The chat selector and the
+ streaming task both key off this flag to prevent hitting an
+ OpenRouter 404 ``"No endpoints found that support image input"``
+ when the user uploads an image and selects a text-only model
+ (DeepSeek V3, Llama 3.x base, etc.).
+ """
+ arch = model.get("architecture", {}) or {}
+ input_mods = arch.get("input_modalities", []) or []
+ return "image" in input_mods
+
+
def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or []
@@ -56,6 +164,11 @@ _EXCLUDED_MODEL_IDS: set[str] = {
# Deep-research models reject standard params (temperature, etc.)
"openai/o3-deep-research",
"openai/o4-mini-deep-research",
+ # OpenRouter's own meta-router over free models. We already enumerate every
+ # concrete ``:free`` model into GLOBAL_LLM_CONFIGS and Auto-mode thread
+ # pinning handles churn via the repair path, so exposing an additional
+ # indirection layer would only duplicate the capability with an opaque slug.
+ "openrouter/free",
}
_EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
@@ -109,24 +222,71 @@ async def _fetch_models_async() -> list[dict] | None:
return None
+def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
+ """Return a ``{model_id: {"prompt": str, "completion": str}}`` map.
+
+ Pricing values are kept as the raw OpenRouter strings (e.g.
+ ``"0.000003"``); ``pricing_registration`` converts them to floats
+ when registering with LiteLLM. Models whose pricing block lacks both
+ fields are omitted entirely; if any of those are premium, their calls
+ debit $0, which is an operator-side risk rather than a user-facing one.
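+
+ Illustrative doctest (hypothetical entries; the second model has no
+ pricing block and is dropped):
+
+ >>> _extract_raw_pricing([
+ ...     {"id": "a/b", "pricing": {"prompt": "0.000003", "completion": "0.00001"}},
+ ...     {"id": "c/d"},
+ ... ])
+ {'a/b': {'prompt': '0.000003', 'completion': '0.00001'}}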
+ """
+ pricing: dict[str, dict[str, str]] = {}
+ for model in raw_models:
+ model_id = str(model.get("id") or "").strip()
+ if not model_id:
+ continue
+ p = model.get("pricing") or {}
+ prompt = p.get("prompt")
+ completion = p.get("completion")
+ if prompt is None and completion is None:
+ continue
+ pricing[model_id] = {
+ "prompt": str(prompt) if prompt is not None else "",
+ "completion": str(completion) if completion is not None else "",
+ }
+ return pricing
+
+
def _generate_configs(
raw_models: list[dict],
settings: dict[str, Any],
) -> list[dict]:
- """
- Convert raw OpenRouter model entries into global LLM config dicts.
+ """Convert raw OpenRouter model entries into global LLM config dicts.
- Models are sorted by ID for deterministic, stable ID assignment across
- restarts and refreshes.
+ Tier (``billing_tier``) is derived per-model from OpenRouter's own API
+ signals via ``_openrouter_tier`` — there is no longer a uniform YAML
+ override. Config IDs are derived via ``_stable_config_id`` so they
+ survive catalogue churn across refreshes.
+
+ Router-pool membership is tier-aware:
+
+ - Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``)
+ so sub-agent ``model="auto"`` flows benefit from load balancing and
+ failover across the curated YAML configs and the OR premium passthrough.
+ - Free OR models stay excluded (``router_pool_eligible=False``). LiteLLM
+ Router tracks rate limits per deployment, but OpenRouter enforces a
+ single global free-tier quota (~20 RPM + 50-1000 daily requests
+ account-wide across every ``:free`` model), so rotating across many
+ free deployments would only burn the shared bucket faster. Free OR
+ models remain fully available for user-facing Auto-mode thread pinning
+ via ``auto_model_pin_service``.
+
+ OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream
+ via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer
+ because our own Auto (Fastest) pin + 24 h refresh + repair logic already
+ cover the catalogue-churn case.
"""
id_offset: int = settings.get("id_offset", -10000)
api_key: str = settings.get("api_key", "")
- billing_tier: str = settings.get("billing_tier", "premium")
- anonymous_enabled: bool = settings.get("anonymous_enabled", False)
seo_enabled: bool = settings.get("seo_enabled", False)
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
rpm: int = settings.get("rpm", 200)
- tpm: int = settings.get("tpm", 1000000)
+ tpm: int = settings.get("tpm", 1_000_000)
+ free_rpm: int = settings.get("free_rpm", 20)
+ free_tpm: int = settings.get("free_tpm", 100_000)
+ anon_paid: bool = settings.get("anonymous_enabled_paid", False)
+ anon_free: bool = settings.get("anonymous_enabled_free", False)
litellm_params: dict = settings.get("litellm_params") or {}
system_instructions: str = settings.get("system_instructions", "")
use_default: bool = settings.get("use_default_system_instructions", True)
@@ -142,19 +302,24 @@ def _generate_configs(
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
- text_models.sort(key=lambda m: m["id"])
configs: list[dict] = []
- for idx, model in enumerate(text_models):
+ taken: set[int] = set()
+ now_ts = int(time.time())
+
+ for model in text_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
+ tier = _openrouter_tier(model)
+
+ static_q = static_score_or(model, now_ts=now_ts)
cfg: dict[str, Any] = {
- "id": id_offset - idx,
+ "id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter",
- "billing_tier": billing_tier,
- "anonymous_enabled": anonymous_enabled,
+ "billing_tier": tier,
+ "anonymous_enabled": anon_free if tier == "free" else anon_paid,
"seo_enabled": seo_enabled,
"seo_slug": None,
"quota_reserve_tokens": quota_reserve_tokens,
@@ -162,12 +327,199 @@ def _generate_configs(
"model_name": model_id,
"api_key": api_key,
"api_base": "",
- "rpm": rpm,
- "tpm": tpm,
+ "rpm": free_rpm if tier == "free" else rpm,
+ "tpm": free_tpm if tier == "free" else tpm,
"litellm_params": dict(litellm_params),
"system_instructions": system_instructions,
"use_default_system_instructions": use_default,
"citations_enabled": citations_enabled,
+ # Premium OR deployments join the LiteLLM router pool so sub-agent
+ # model="auto" flows can load-balance / fail over across them.
+ # Free OR deployments stay out: OpenRouter's free tier is a single
+ # account-wide quota, so per-deployment routing can't spread load
+ # there — it just drains the shared bucket faster.
+ "router_pool_eligible": tier == "premium",
+ # Capability flag derived from ``architecture.input_modalities``.
+ # Read by the new-chat selector to dim image-incompatible models
+ # when the user has pending image attachments, and by
+ # ``stream_new_chat`` as a fail-fast safety net before the
+ # OpenRouter request would otherwise 404 with
+ # ``"No endpoints found that support image input"``.
+ "supports_image_input": _supports_image_input(model),
+ _OPENROUTER_DYNAMIC_MARKER: True,
+ # Auto (Fastest) ranking metadata. ``quality_score`` is initialised
+ # to the static score and gets re-blended with health on the next
+ # ``_enrich_health`` pass (synchronous on refresh, deferred on cold
+ # start so startup latency is unchanged).
+ "auto_pin_tier": "B" if tier == "premium" else "C",
+ "quality_score_static": static_q,
+ "quality_score_health": None,
+ "quality_score": static_q,
+ "health_gated": False,
+ }
+ configs.append(cfg)
+
+ return configs
+
+
+# ID-offset band starts used to give each surface's dynamic OpenRouter
+# configs its own namespace. Each surface keeps its own config list, and the
+# distinct offsets mean the same model_id never hashes to the same ID on two
+# surfaces. The bands are starting points, not disjoint ranges: the hash
+# space (9M) is far wider than the 10k gap between offsets, so IDs from
+# different surfaces can coincide, which is harmless because lookups never
+# cross surfaces.
+_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
+_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
+
+
+def _generate_image_gen_configs(
+ raw_models: list[dict], settings: dict[str, Any]
+) -> list[dict]:
+ """Convert OpenRouter image-generation models into global image-gen
+ config dicts (matches the YAML shape consumed by ``image_generation_routes``).
+
+ Filter:
+ - architecture.output_modalities contains "image"
+ - compatible provider (excluded slugs blocked)
+ - allowed model id (excluded list blocked)
+
+ Notably we *drop* the chat-only filters (``_supports_tool_calling`` and
+ ``_has_sufficient_context``) because tool calls and context windows are
+ irrelevant for the ``aimage_generation`` API. ``billing_tier`` is
+ derived per model the same way as chat (``_openrouter_tier``).
+
+ Cost is intentionally *not* registered with LiteLLM at startup
+ (``pricing_registration`` skips image gen): OpenRouter image-gen
+ models are not in LiteLLM's native cost map and OpenRouter populates
+ ``response_cost`` directly from the response header. A defensive
+ branch in ``_extract_cost_usd`` handles the rare case where
+ ``usage.cost`` is missing — see ``token_tracking_service``.
+ """
+ id_offset: int = int(
+ settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
+ )
+ api_key: str = settings.get("api_key", "")
+ rpm: int = settings.get("rpm", 200)
+ free_rpm: int = settings.get("free_rpm", 20)
+ litellm_params: dict = settings.get("litellm_params") or {}
+
+ image_models = [
+ m
+ for m in raw_models
+ if _is_image_output_model(m)
+ and _is_compatible_provider(m)
+ and _is_allowed_model(m)
+ and "/" in m.get("id", "")
+ ]
+
+ configs: list[dict] = []
+ taken: set[int] = set()
+ for model in image_models:
+ model_id: str = model["id"]
+ name: str = model.get("name", model_id)
+ tier = _openrouter_tier(model)
+
+ cfg: dict[str, Any] = {
+ "id": _stable_config_id(model_id, id_offset, taken),
+ "name": name,
+ "description": f"{name} via OpenRouter (image generation)",
+ "provider": "OPENROUTER",
+ "model_name": model_id,
+ "api_key": api_key,
+ # Pin to OpenRouter's public base URL so a downstream call site
+ # that forgets ``resolve_api_base`` still doesn't inherit
+ # ``AZURE_OPENAI_ENDPOINT`` and 404 on
+ # ``image_generation/transformation`` (defense-in-depth, see
+ # ``provider_api_base`` docstring).
+ "api_base": "https://openrouter.ai/api/v1",
+ "api_version": None,
+ "rpm": free_rpm if tier == "free" else rpm,
+ "litellm_params": dict(litellm_params),
+ "billing_tier": tier,
+ _OPENROUTER_DYNAMIC_MARKER: True,
+ }
+ configs.append(cfg)
+
+ return configs
+
+
+def _generate_vision_llm_configs(
+ raw_models: list[dict], settings: dict[str, Any]
+) -> list[dict]:
+ """Convert OpenRouter vision-capable LLMs into global vision-LLM config
+ dicts (matches the YAML shape consumed by ``vision_llm_routes``).
+
+ Filter:
+ - architecture.input_modalities contains "image"
+ - architecture.output_modalities contains "text"
+ - compatible provider (excluded slugs blocked)
+ - allowed model id (excluded list blocked)
+
+ Vision-LLM is invoked from the indexer (image extraction during
+ document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
+ the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
+ filters do not apply: a small-context vision model that doesn't
+ advertise tool-calling is still perfectly viable for "describe this
+ image" prompts.
+ """
+ id_offset: int = int(
+ settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
+ )
+ api_key: str = settings.get("api_key", "")
+ rpm: int = settings.get("rpm", 200)
+ tpm: int = settings.get("tpm", 1_000_000)
+ free_rpm: int = settings.get("free_rpm", 20)
+ free_tpm: int = settings.get("free_tpm", 100_000)
+ quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
+ litellm_params: dict = settings.get("litellm_params") or {}
+
+ vision_models = [
+ m
+ for m in raw_models
+ if _is_vision_input_model(m)
+ and _is_compatible_provider(m)
+ and _is_allowed_model(m)
+ and "/" in m.get("id", "")
+ ]
+
+ configs: list[dict] = []
+ taken: set[int] = set()
+ for model in vision_models:
+ model_id: str = model["id"]
+ name: str = model.get("name", model_id)
+ tier = _openrouter_tier(model)
+ pricing = model.get("pricing") or {}
+
+ # Capture per-token prices so ``pricing_registration`` can
+ # register them with LiteLLM at startup (and so the cost
+ # estimator in ``estimate_call_reserve_micros`` can resolve
+ # them at reserve time).
+ try:
+ input_cost = float(pricing.get("prompt", 0) or 0)
+ except (TypeError, ValueError):
+ input_cost = 0.0
+ try:
+ output_cost = float(pricing.get("completion", 0) or 0)
+ except (TypeError, ValueError):
+ output_cost = 0.0
+
+ cfg: dict[str, Any] = {
+ "id": _stable_config_id(model_id, id_offset, taken),
+ "name": name,
+ "description": f"{name} via OpenRouter (vision)",
+ "provider": "OPENROUTER",
+ "model_name": model_id,
+ "api_key": api_key,
+ # Pin to OpenRouter's public base URL so a downstream call site
+ # that forgets ``resolve_api_base`` still doesn't inherit
+ # ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see
+ # ``provider_api_base`` docstring).
+ "api_base": "https://openrouter.ai/api/v1",
+ "api_version": None,
+ "rpm": free_rpm if tier == "free" else rpm,
+ "tpm": free_tpm if tier == "free" else tpm,
+ "litellm_params": dict(litellm_params),
+ "billing_tier": tier,
+ "quota_reserve_tokens": quota_reserve_tokens,
+ "input_cost_per_token": input_cost or None,
+ "output_cost_per_token": output_cost or None,
_OPENROUTER_DYNAMIC_MARKER: True,
}
configs.append(cfg)
@@ -187,6 +539,25 @@ class OpenRouterIntegrationService:
self._configs_by_id: dict[int, dict] = {}
self._initialized = False
self._refresh_task: asyncio.Task | None = None
+ # Last-good per-model health snapshot. Survives across refresh
+ # cycles so a transient OpenRouter /endpoints outage doesn't drop
+ # every cfg back to static-only scoring.
+ # Shape: {model_name: {"gated": bool, "score": float | None}}
+ self._health_cache: dict[str, dict[str, Any]] = {}
+ self._enrich_task: asyncio.Task | None = None
+ # Raw OpenRouter pricing per model_id, captured at the same time
+ # we generate configs. Consumed by ``pricing_registration`` to
+ # teach LiteLLM the per-token cost of every dynamic deployment so
+ # the success-callback can populate ``response_cost`` correctly.
+ self._raw_pricing: dict[str, dict[str, str]] = {}
+ # Cached raw catalogue from the most recent fetch. Image / vision
+ # emitters reuse this to avoid a second network call per surface.
+ self._raw_models: list[dict] = []
+ # Image / vision config caches (only populated when the matching
+ # opt-in flag is true on initialize). Refreshed in lockstep with
+ # the chat catalogue.
+ self._image_configs: list[dict] = []
+ self._vision_configs: list[dict] = []
@classmethod
def get_instance(cls) -> "OpenRouterIntegrationService":
@@ -216,16 +587,55 @@ class OpenRouterIntegrationService:
self._initialized = True
return []
+ self._raw_models = raw_models
self._configs = _generate_configs(raw_models, settings)
self._configs_by_id = {c["id"]: c for c in self._configs}
+ self._raw_pricing = _extract_raw_pricing(raw_models)
+
+ # Populate image / vision caches when their opt-in flag is set.
+ # Empty otherwise so the accessors return [] without re-running
+ # filters every refresh.
+ if settings.get("image_generation_enabled"):
+ self._image_configs = _generate_image_gen_configs(raw_models, settings)
+ logger.info(
+ "OpenRouter integration: image-gen emission ON (%d models)",
+ len(self._image_configs),
+ )
+ else:
+ self._image_configs = []
+
+ if settings.get("vision_enabled"):
+ self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
+ logger.info(
+ "OpenRouter integration: vision LLM emission ON (%d models)",
+ len(self._vision_configs),
+ )
+ else:
+ self._vision_configs = []
+
self._initialized = True
+ tier_counts = self._tier_counts(self._configs)
logger.info(
- "OpenRouter integration: loaded %d models (IDs %d to %d)",
+ "OpenRouter integration: loaded %d models (free=%d, premium=%d)",
len(self._configs),
- self._configs[0]["id"] if self._configs else 0,
- self._configs[-1]["id"] if self._configs else 0,
+ tier_counts["free"],
+ tier_counts["premium"],
)
+
+ # Schedule the first health-enrichment pass as a deferred task so
+ # cold-start latency is unchanged. Only valid when an event loop is
+ # already running (e.g. FastAPI lifespan); Celery worker init is
+ # fully sync so we silently skip — its first refresh tick (or the
+ # next refresh from the web process) will populate health data.
+ try:
+ loop = asyncio.get_running_loop()
+ self._enrich_task = loop.create_task(
+ self._enrich_health_safely(self._configs)
+ )
+ except RuntimeError:
+ pass
+
return self._configs
# ------------------------------------------------------------------
@@ -241,6 +651,8 @@ class OpenRouterIntegrationService:
new_configs = _generate_configs(raw_models, self._settings)
new_by_id = {c["id"]: c for c in new_configs}
+ self._raw_pricing = _extract_raw_pricing(raw_models)
+ self._raw_models = raw_models
from app.config import config as app_config
@@ -254,7 +666,263 @@ class OpenRouterIntegrationService:
self._configs = new_configs
self._configs_by_id = new_by_id
- logger.info("OpenRouter refresh: updated to %d models", len(new_configs))
+ # Image / vision lists are atomic-swapped the same way: filter out
+ # the previous dynamic entries from the live config list and append
+ # the freshly generated ones. No-ops when the opt-in flag is off.
+ if self._settings.get("image_generation_enabled"):
+ new_image = _generate_image_gen_configs(raw_models, self._settings)
+ static_image = [
+ c
+ for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
+ if not c.get(_OPENROUTER_DYNAMIC_MARKER)
+ ]
+ app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
+ self._image_configs = new_image
+
+ if self._settings.get("vision_enabled"):
+ new_vision = _generate_vision_llm_configs(raw_models, self._settings)
+ static_vision = [
+ c
+ for c in app_config.GLOBAL_VISION_LLM_CONFIGS
+ if not c.get(_OPENROUTER_DYNAMIC_MARKER)
+ ]
+ app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
+ self._vision_configs = new_vision
+
+ # Catalogue churn invalidates per-config "recently healthy" credit
+ # earned by the previous turn's preflight. Drop the whole table so
+ # the next turn re-probes against the freshly loaded configs.
+ try:
+ from app.services.auto_model_pin_service import clear_healthy
+
+ clear_healthy()
+ except Exception:
+ logger.debug(
+ "OpenRouter refresh: clear_healthy import skipped", exc_info=True
+ )
+
+ tier_counts = self._tier_counts(new_configs)
+ logger.info(
+ "OpenRouter refresh: updated to %d models (free=%d, premium=%d)",
+ len(new_configs),
+ tier_counts["free"],
+ tier_counts["premium"],
+ )
+
+ # Re-blend health scores against the freshly fetched catalogue. Also
+ # re-stamps health for any YAML-curated cfg with provider==OPENROUTER
+ # so a hand-picked dead OR model is gated like a dynamic one.
+ await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
+
+ # Re-register LiteLLM pricing for the freshly fetched catalogue
+ # so newly added OR models bill correctly on their first call.
+ # Runs before the router rebuild because the router may issue
+ # cost-table lookups during deployment registration.
+ try:
+ from app.services.pricing_registration import (
+ register_pricing_from_global_configs,
+ )
+
+ register_pricing_from_global_configs()
+ except Exception as exc:
+ logger.warning(
+ "OpenRouter refresh: pricing re-registration skipped (%s)", exc
+ )
+
+ # Rebuild the LiteLLM router so freshly fetched configs flow through
+ # (dynamic OR premium entries now opt into the pool, free ones stay
+ # out; a refresh also needs to pick up any static-config edits and
+ # reset cached context-window profiles).
+ try:
+ from app.config import config as _app_config
+ from app.services.llm_router_service import (
+ LLMRouterService,
+ _router_instance_cache as _chat_router_cache,
+ )
+
+ LLMRouterService.rebuild(
+ _app_config.GLOBAL_LLM_CONFIGS,
+ getattr(_app_config, "ROUTER_SETTINGS", None),
+ )
+ _chat_router_cache.clear()
+ except Exception as exc:
+ logger.warning("OpenRouter refresh: router rebuild skipped (%s)", exc)
+
+ @staticmethod
+ def _tier_counts(configs: list[dict]) -> dict[str, int]:
+ counts = {"free": 0, "premium": 0}
+ for cfg in configs:
+ tier = str(cfg.get("billing_tier", "")).lower()
+ if tier in counts:
+ counts[tier] += 1
+ return counts
+
+ # ------------------------------------------------------------------
+ # Auto (Fastest) health enrichment
+ # ------------------------------------------------------------------
+
+ async def _enrich_health_safely(
+ self, configs: list[dict], *, log_summary: bool = True
+ ) -> None:
+ """Wrapper around ``_enrich_health`` that swallows all errors.
+
+ Health enrichment is best-effort: any failure must leave cfgs in
+ their static-only state and never break refresh / startup.
+ """
+ try:
+ await self._enrich_health(configs, log_summary=log_summary)
+ except Exception:
+ logger.exception("OpenRouter health enrichment failed")
+
+ async def _enrich_health(
+ self, configs: list[dict], *, log_summary: bool = True
+ ) -> None:
+ """Fetch per-model ``/endpoints`` data for the top OR cfgs and blend
+ the resulting health score into ``cfg["quality_score"]``.
+
+ Bounded fan-out: top-N per tier by ``quality_score_static`` only,
+ with ``asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)`` guarding the
+ outbound HTTP. Misses fall back to a per-model last-good cache; if
+ the failure ratio crosses ``_HEALTH_FAIL_RATIO_FALLBACK`` we keep
+ the entire previous cycle's cache for this run.
+ """
+ or_cfgs = [
+ c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER"
+ ]
+ if not or_cfgs:
+ return
+
+ premium_pool = sorted(
+ [c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "premium"],
+ key=lambda c: -int(c.get("quality_score_static") or 0),
+ )[:_HEALTH_ENRICH_TOP_N_PREMIUM]
+ free_pool = sorted(
+ [c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "free"],
+ key=lambda c: -int(c.get("quality_score_static") or 0),
+ )[:_HEALTH_ENRICH_TOP_N_FREE]
+ # De-duplicate while preserving order: a cfg shouldn't fall in both
+ # tiers, but defensive code is cheap here.
+ seen_ids: set[int] = set()
+ selected: list[dict] = []
+ for cfg in premium_pool + free_pool:
+ cid = int(cfg.get("id", 0))
+ if cid in seen_ids:
+ continue
+ seen_ids.add(cid)
+ selected.append(cfg)
+
+ if not selected:
+ return
+
+ api_key = str(self._settings.get("api_key") or "")
+ semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)
+
+ async with httpx.AsyncClient(timeout=_HEALTH_FETCH_TIMEOUT_SEC) as client:
+ results = await asyncio.gather(
+ *(
+ self._fetch_endpoints(client, semaphore, api_key, cfg)
+ for cfg in selected
+ )
+ )
+
+ fail_count = sum(1 for _, _, err in results if err is not None)
+ fail_ratio = fail_count / len(results) if results else 0.0
+ degraded = fail_ratio >= _HEALTH_FAIL_RATIO_FALLBACK
+ if degraded:
+ logger.warning(
+ "auto_pin_health_enrich_degraded fail_ratio=%.2f total=%d "
+ "using_last_good_cache=true",
+ fail_ratio,
+ len(results),
+ )
+
+ # Per-cfg health update.
+ for cfg, endpoints, err in results:
+ model_name = str(cfg.get("model_name", ""))
+ if not degraded and err is None and endpoints is not None:
+ gated, h_score = aggregate_health(endpoints)
+ cfg["health_gated"] = bool(gated)
+ cfg["quality_score_health"] = h_score
+ self._health_cache[model_name] = {
+ "gated": bool(gated),
+ "score": h_score,
+ }
+ else:
+ cached = self._health_cache.get(model_name)
+ if cached is not None:
+ cfg["health_gated"] = bool(cached.get("gated", False))
+ cfg["quality_score_health"] = cached.get("score")
+ # else: keep current values (initial defaults from
+ # _generate_configs / load_global_llm_configs).
+
+ # Blend health into the final score for every OR cfg, including
+ # those outside the enriched top-N (they fall through to static).
+ gated_count = 0
+ by_provider: dict[str, int] = {}
+ for cfg in or_cfgs:
+ static_q = int(cfg.get("quality_score_static") or 0)
+ h = cfg.get("quality_score_health")
+ if h is not None and not cfg.get("health_gated"):
+ blended = (
+ _HEALTH_BLEND_WEIGHT * float(h)
+ + (1 - _HEALTH_BLEND_WEIGHT) * static_q
+ )
+ cfg["quality_score"] = round(blended)
+ else:
+ cfg["quality_score"] = static_q
+
+ if cfg.get("health_gated"):
+ gated_count += 1
+ model_id = str(cfg.get("model_name", ""))
+ provider_slug = (
+ model_id.split("/", 1)[0] if "/" in model_id else "unknown"
+ )
+ by_provider[provider_slug] = by_provider.get(provider_slug, 0) + 1
+
+ if log_summary:
+ logger.info(
+ "auto_pin_health_gated count=%d by_provider=%s fail_ratio=%.2f "
+ "total_enriched=%d",
+ gated_count,
+ dict(sorted(by_provider.items(), key=lambda kv: -kv[1])),
+ fail_ratio,
+ len(selected),
+ )
+
+ @staticmethod
+ async def _fetch_endpoints(
+ client: httpx.AsyncClient,
+ semaphore: asyncio.Semaphore,
+ api_key: str,
+ cfg: dict,
+ ) -> tuple[dict, list[dict] | None, Exception | None]:
+ """Fetch ``/api/v1/models/{id}/endpoints`` for one cfg.
+
+ Returns ``(cfg, endpoints, err)`` so the caller can keep batched
+ results aligned with their cfgs without raising.
+ """
+ model_id = str(cfg.get("model_name", ""))
+ if not model_id:
+ return cfg, None, ValueError("missing model_name")
+
+ url = OPENROUTER_ENDPOINTS_URL_TEMPLATE.format(model_id=model_id)
+ headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
+
+ async with semaphore:
+ try:
+ resp = await client.get(url, headers=headers)
+ resp.raise_for_status()
+ data = resp.json()
+ except Exception as exc:
+ return cfg, None, exc
+
+ payload = data.get("data") if isinstance(data, dict) else None
+ if not isinstance(payload, dict):
+ return cfg, None, ValueError("malformed endpoints payload")
+ endpoints = payload.get("endpoints")
+ if not isinstance(endpoints, list):
+ return cfg, [], None
+ return cfg, endpoints, None
async def _refresh_loop(self, interval_hours: float) -> None:
interval_sec = interval_hours * 3600
@@ -289,3 +957,34 @@ class OpenRouterIntegrationService:
def get_config_by_id(self, config_id: int) -> dict | None:
return self._configs_by_id.get(config_id)
+
+ def get_image_generation_configs(self) -> list[dict]:
+ """Return the dynamic OpenRouter image-generation configs (empty
+ list when the ``image_generation_enabled`` flag is off).
+
+ Each entry already has ``billing_tier`` derived per-model from
+ OpenRouter's signals and is shaped to drop directly into
+ ``Config.GLOBAL_IMAGE_GEN_CONFIGS``.
+ """
+ return list(self._image_configs)
+
+ def get_vision_llm_configs(self) -> list[dict]:
+ """Return the dynamic OpenRouter vision-LLM configs (empty list
+ when the ``vision_enabled`` flag is off).
+
+ Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
+ so ``pricing_registration`` can teach LiteLLM the cost of these
+ models the same way it does for chat — which keeps the billable
+ wrapper able to debit accurate micro-USD on a vision call.
+ """
+ return list(self._vision_configs)
+
+ def get_raw_pricing(self) -> dict[str, dict[str, str]]:
+ """Return the cached raw OpenRouter pricing map.
+
+ Shape: ``{model_id: {"prompt": str, "completion": str}}``. The
+ values are the strings OpenRouter publishes (USD per token),
+ never converted to floats here so the caller can decide how to
+ handle malformed or unset entries.
+ """
+ return dict(self._raw_pricing)
diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py
new file mode 100644
index 000000000..de98e50c2
--- /dev/null
+++ b/surfsense_backend/app/services/pricing_registration.py
@@ -0,0 +1,274 @@
+"""
+Pricing registration with LiteLLM.
+
+Many models reach our LiteLLM callback without LiteLLM knowing their
+per-token cost — namely:
+
+* The ~300 dynamic OpenRouter deployments (their pricing only lives on
+ OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
+ pricing table).
+* Static YAML deployments whose ``base_model`` name is operator-defined
+ (e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
+ not in LiteLLM's table either.
+
+Without registration, ``kwargs["response_cost"]`` is 0 for those calls
+and the user gets billed nothing — a fail-safe but wrong answer for a
+cost-based credit system. This module runs once at startup, after the
+OpenRouter integration has fetched its catalogue, and registers each
+known model's pricing with ``litellm.register_model()`` under multiple
+plausible alias keys (LiteLLM's cost lookup may use any of them
+depending on whether the call went through the Router, ChatLiteLLM,
+or a direct ``acompletion``).
+
+Operators who run a custom Azure deployment whose ``base_model`` name
+isn't in LiteLLM's table can declare per-token pricing inline in
+``global_llm_config.yaml`` via ``input_cost_per_token`` and
+``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
+that declaration the model's calls debit 0 — never overbilled.
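+
+Illustrative YAML shape (a hedged sketch; the deployment and model names are
+hypothetical, and only the pricing keys are read by this module)::
+
+    - provider: AZURE_OPENAI
+      model_name: my-gpt5-deployment
+      litellm_params:
+        base_model: gpt-5.4
+      input_cost_per_token: 0.000002
+      output_cost_per_token: 0.000008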
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+import litellm
+
+logger = logging.getLogger(__name__)
+
+
+def _safe_float(value: Any) -> float:
+ """Return ``float(value)`` if it parses to a positive number, else 0.0."""
+ if value is None:
+ return 0.0
+ try:
+ f = float(value)
+ except (TypeError, ValueError):
+ return 0.0
+ return f if f > 0 else 0.0
+
+
+def _alias_set_for_openrouter(model_id: str) -> list[str]:
+ """Return the alias keys to register an OpenRouter model under.
+
+ LiteLLM's cost-callback lookup key varies by call path:
+ - Router with ``model="openrouter/X"`` → kwargs["model"] is
+ typically ``openrouter/X``.
+ - LiteLLM's own provider routing may strip the prefix and pass the
+ bare ``X`` to the cost-table lookup.
+ Registering under both keeps the cost lookup reliable regardless of
+ which path the call took.
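+
+ Illustrative doctest (hypothetical model id):
+
+ >>> _alias_set_for_openrouter("openai/gpt-4o")
+ ['openrouter/openai/gpt-4o', 'openai/gpt-4o']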
+ """
+ aliases = [f"openrouter/{model_id}", model_id]
+ return list(dict.fromkeys(a for a in aliases if a))
+
+
+def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
+ """Return the alias keys to register a static YAML deployment under.
+
+ Same reasoning as the OpenRouter set: cover the bare ``base_model``,
+ the ``provider/base_model`` form LiteLLM Router constructs, and the
+ bare ``model_name`` because callbacks sometimes see whichever was
+ configured first.
+ """
+ provider_lower = (provider or "").lower()
+ aliases: list[str] = []
+ if base_model:
+ aliases.append(base_model)
+ if provider_lower and base_model:
+ aliases.append(f"{provider_lower}/{base_model}")
+ if model_name and model_name != base_model:
+ aliases.append(model_name)
+ if provider_lower and model_name and model_name != base_model:
+ aliases.append(f"{provider_lower}/{model_name}")
+ # Azure deployments often surface as "azure/"; normalise the
+ # ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
+ if provider_lower == "azure_openai":
+ if base_model:
+ aliases.append(f"azure/{base_model}")
+ if model_name and model_name != base_model:
+ aliases.append(f"azure/{model_name}")
+ return list(dict.fromkeys(a for a in aliases if a))
+
+
+def _register(
+ aliases: list[str],
+ *,
+ input_cost: float,
+ output_cost: float,
+ provider: str,
+ mode: str = "chat",
+) -> int:
+ """Register a single pricing entry under every alias in ``aliases``.
+
+ Returns the count of aliases successfully registered.
+ """
+ payload: dict[str, dict[str, Any]] = {}
+ for alias in aliases:
+ payload[alias] = {
+ "input_cost_per_token": input_cost,
+ "output_cost_per_token": output_cost,
+ "litellm_provider": provider,
+ "mode": mode,
+ }
+ if not payload:
+ return 0
+ try:
+ litellm.register_model(payload)
+ except Exception as exc:
+ logger.warning(
+ "[PricingRegistration] register_model failed for aliases=%s: %s",
+ aliases,
+ exc,
+ )
+ return 0
+ return len(payload)
+
+
+def _register_chat_shape_configs(
+ configs: list[dict],
+ *,
+ or_pricing: dict[str, dict[str, str]],
+ label: str,
+) -> tuple[int, int, int, list[str]]:
+ """Common loop that registers per-token pricing for a list of "chat-shape"
+ configs (chat or vision LLM — both use ``input_cost_per_token`` /
+ ``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).
+
+ Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
+ """
+ registered_models = 0
+ registered_aliases = 0
+ skipped_no_pricing = 0
+ sample_keys: list[str] = []
+
+ for cfg in configs:
+ provider = str(cfg.get("provider") or "").upper()
+ model_name = str(cfg.get("model_name") or "").strip()
+ litellm_params = cfg.get("litellm_params") or {}
+ base_model = str(litellm_params.get("base_model") or model_name).strip()
+
+ if provider == "OPENROUTER":
+ entry = or_pricing.get(model_name)
+ if entry:
+ input_cost = _safe_float(entry.get("prompt"))
+ output_cost = _safe_float(entry.get("completion"))
+ else:
+ # Vision configs from ``_generate_vision_llm_configs`` carry
+ # their pricing inline; when the raw-pricing cache misses a
+ # model (entry absent from the bulk payload, or the cache not
+ # yet populated), fall back to those inline values.
+ input_cost = _safe_float(cfg.get("input_cost_per_token"))
+ output_cost = _safe_float(cfg.get("output_cost_per_token"))
+ if input_cost == 0.0 and output_cost == 0.0:
+ skipped_no_pricing += 1
+ continue
+ aliases = _alias_set_for_openrouter(model_name)
+ count = _register(
+ aliases,
+ input_cost=input_cost,
+ output_cost=output_cost,
+ provider="openrouter",
+ )
+ if count > 0:
+ registered_models += 1
+ registered_aliases += count
+ if len(sample_keys) < 6:
+ sample_keys.extend(aliases[:2])
+ continue
+
+ input_cost = _safe_float(
+ cfg.get("input_cost_per_token")
+ or litellm_params.get("input_cost_per_token")
+ )
+ output_cost = _safe_float(
+ cfg.get("output_cost_per_token")
+ or litellm_params.get("output_cost_per_token")
+ )
+ if input_cost == 0.0 and output_cost == 0.0:
+ skipped_no_pricing += 1
+ continue
+ aliases = _alias_set_for_yaml(provider, model_name, base_model)
+ provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
+ count = _register(
+ aliases,
+ input_cost=input_cost,
+ output_cost=output_cost,
+ provider=provider_slug,
+ )
+ if count > 0:
+ registered_models += 1
+ registered_aliases += count
+ if len(sample_keys) < 6:
+ sample_keys.extend(aliases[:2])
+
+ logger.info(
+ "[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
+ "%d configs had no pricing data; sample registered keys=%s",
+ label,
+ registered_models,
+ registered_aliases,
+ skipped_no_pricing,
+ sample_keys,
+ )
+ return registered_models, registered_aliases, skipped_no_pricing, sample_keys
+
+
+def register_pricing_from_global_configs() -> None:
+ """Register pricing for every known LLM deployment with LiteLLM.
+
+ Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
+ so vision calls (during indexing) can resolve cost the same way chat
+ calls do — namely:
+
+ 1. ``OPENROUTER``: pulls the cached raw pricing from
+ ``OpenRouterIntegrationService`` (populated during its own
+ startup fetch) and converts the per-token strings to floats. For
+ vision configs that carry pricing inline (``input_cost_per_token`` /
+ ``output_cost_per_token`` set on the cfg itself) we fall back to
+ those values when the OR cache misses the model.
+ 2. Anything else: looks for operator-declared
+ ``input_cost_per_token`` / ``output_cost_per_token`` on the YAML
+ config block (top-level or nested under ``litellm_params``).
+
+ **Image generation is intentionally NOT registered here.** The cost
+ shape for image-gen is per-image (``output_cost_per_image``), not
+ per-token, and LiteLLM's ``register_model`` doesn't accept those
+ keys via the chat-cost path. OpenRouter image-gen models populate
+ ``response_cost`` directly from their response header instead, and
+ Azure-native image-gen models are already in LiteLLM's cost map.
+
+ Calls without a resolved pair of costs are skipped, not registered
+ with zeros — operators who forget pricing get a "$0 debit" warning
+ in ``TokenTrackingCallback`` rather than silently overwriting any
+ pricing LiteLLM might know natively.
+ """
+ from app.config import config as app_config
+
+ chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
+ vision_configs: list[dict] = list(
+ getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
+ )
+ if not chat_configs and not vision_configs:
+ logger.info("[PricingRegistration] no global configs to register")
+ return
+
+ or_pricing: dict[str, dict[str, str]] = {}
+ try:
+ from app.services.openrouter_integration_service import (
+ OpenRouterIntegrationService,
+ )
+
+ if OpenRouterIntegrationService.is_initialized():
+ or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
+ except Exception as exc:
+ logger.debug(
+ "[PricingRegistration] OpenRouter pricing not available yet: %s", exc
+ )
+
+ if chat_configs:
+ _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
+ if vision_configs:
+ _register_chat_shape_configs(
+ vision_configs, or_pricing=or_pricing, label="vision"
+ )
diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py
new file mode 100644
index 000000000..dca1f9462
--- /dev/null
+++ b/surfsense_backend/app/services/provider_api_base.py
@@ -0,0 +1,106 @@
+"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
+
+LiteLLM falls back to the module-global ``litellm.api_base`` when an
+individual call doesn't pass one, which silently inherits provider-agnostic
+env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
+explicit ``api_base``, an ``openrouter/`` request can end up at an
+Azure endpoint and 404 with ``Resource not found`` (real reproducer:
+[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
+``/chat/completions`` to whatever inherited base it gets, regardless of
+provider).
+
+The chat router has had this defense for a while
+(``llm_router_service.py:466-478``). This module hoists the maps + cascade
+into a tiny standalone helper so vision and image-gen can share the same
+source of truth without an inter-service circular import.
+"""
+
+from __future__ import annotations
+
+PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
+ "openrouter": "https://openrouter.ai/api/v1",
+ "groq": "https://api.groq.com/openai/v1",
+ "mistral": "https://api.mistral.ai/v1",
+ "perplexity": "https://api.perplexity.ai",
+ "xai": "https://api.x.ai/v1",
+ "cerebras": "https://api.cerebras.ai/v1",
+ "deepinfra": "https://api.deepinfra.com/v1/openai",
+ "fireworks_ai": "https://api.fireworks.ai/inference/v1",
+ "together_ai": "https://api.together.xyz/v1",
+ "anyscale": "https://api.endpoints.anyscale.com/v1",
+ "cometapi": "https://api.cometapi.com/v1",
+ "sambanova": "https://api.sambanova.ai/v1",
+}
+"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
+
+Only providers with a well-known, stable public base URL are listed —
+self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
+huggingface, databricks, cloudflare, replicate) are intentionally omitted
+so their existing config-driven behaviour is preserved."""
+
+
+PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
+ "DEEPSEEK": "https://api.deepseek.com/v1",
+ "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
+ "MOONSHOT": "https://api.moonshot.ai/v1",
+ "ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
+ "MINIMAX": "https://api.minimax.io/v1",
+}
+"""Canonical provider key (uppercase) → base URL.
+
+Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
+config's ``provider`` field tells us which API it actually is (DeepSeek,
+Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
+has its own base URL)."""
+
+
+def resolve_api_base(
+ *,
+ provider: str | None,
+ provider_prefix: str | None,
+ config_api_base: str | None,
+) -> str | None:
+ """Resolve a non-Azure-leaking ``api_base`` for a deployment.
+
+ Cascade (first non-empty wins):
+ 1. The config's own ``api_base`` (whitespace-only treated as missing).
+ 2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
+ 3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
+ 4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM
+ provider integration apply its own default (e.g. AzureOpenAI's
+ deployment-derived URL, custom provider's per-deployment URL).
+
+ Args:
+ provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
+ ``"DEEPSEEK"``). Case-insensitive.
+ provider_prefix: The LiteLLM model-string prefix the same call
+ site builds for the model id (e.g. ``"openrouter"``,
+ ``"groq"``). Case-insensitive.
+ config_api_base: ``api_base`` from the global YAML / DB row /
+ OpenRouter dynamic config. Empty / whitespace-only means
+ "missing" — the resolver still applies the cascade.
+
+ Returns:
+ A URL string, or ``None`` if no default applies for this provider.
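+
+ Illustrative doctest (URLs from the tables above; the first call also
+ shows whitespace-only ``config_api_base`` being treated as missing):
+
+ >>> resolve_api_base(provider="OPENROUTER", provider_prefix="openrouter",
+ ...                  config_api_base=" ")
+ 'https://openrouter.ai/api/v1'
+ >>> resolve_api_base(provider="DEEPSEEK", provider_prefix="openai",
+ ...                  config_api_base=None)
+ 'https://api.deepseek.com/v1'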
+ """
+ if config_api_base and config_api_base.strip():
+ return config_api_base
+
+ if provider:
+ key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
+ if key_default:
+ return key_default
+
+ if provider_prefix:
+ prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
+ if prefix_default:
+ return prefix_default
+
+ return None
+
+
+__all__ = [
+ "PROVIDER_DEFAULT_API_BASE",
+ "PROVIDER_KEY_DEFAULT_API_BASE",
+ "resolve_api_base",
+]
diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py
new file mode 100644
index 000000000..e9a1c33e1
--- /dev/null
+++ b/surfsense_backend/app/services/provider_capabilities.py
@@ -0,0 +1,280 @@
+"""Capability resolution shared by chat / image / vision call sites.
+
+Why this exists
+---------------
+The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a
+single, authoritative answer to one question: *can this chat config accept
+``image_url`` content blocks?* Without it, the new-chat selector can't badge
+incompatible models and the streaming task can't fail fast with a friendly
+error before sending an image to a text-only provider.
+
+Two functions, two intents:
+
+- :func:`derive_supports_image_input` — best-effort *True* for catalog and
+ UI surfacing. Default-allow: an unknown / unmapped model is treated as
+ capable so we never lock the user out of a freshly added or
+ third-party-hosted vision model.
+
+- :func:`is_known_text_only_chat_model` — strict opt-out for the streaming
+ task's safety net. Returns True only when LiteLLM's model map *explicitly*
+ sets ``supports_vision=False`` (or its bare-name variant does). Anything
+ else — missing key, lookup exception, ``supports_vision=True`` — returns
+ False so the request flows through to the provider.
+
+Implementation rule: only public LiteLLM symbols
+------------------------------------------------
+``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the
+typed module surface (see ``litellm.__init__`` lazy stubs) and are stable
+across releases. The private ``_is_explicitly_disabled_factory`` and
+``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade
+can't silently break us.
+
+Why the previous round's strict YAML opt-in flag failed
+-------------------------------------------------------
+``supports_image_input: false`` was the YAML loader's setdefault. Operators
+maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI
+YAML chat model — including vision-capable GPT-5.x and GPT-4o — resolved to
+False and the streaming gate rejected every image turn. Sourcing capability
+from LiteLLM's authoritative model map (which already says
+``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Iterable
+
+import litellm
+
+logger = logging.getLogger(__name__)
+
+
+# Provider-name → LiteLLM model-prefix map.
+#
+# Owned here because ``app.services.provider_capabilities`` is the
+# only edge that's safe to call from ``app.config``'s YAML loader at
+# class-body init time. ``app.agents.new_chat.llm_config`` re-exports
+# this constant under the historical ``PROVIDER_MAP`` name; placing the
+# map there directly would re-introduce the
+# ``app.config -> ... -> app.agents.new_chat.tools.generate_image ->
+# app.config`` cycle that prompted the move.
+_PROVIDER_PREFIX_MAP: dict[str, str] = {
+ "OPENAI": "openai",
+ "ANTHROPIC": "anthropic",
+ "GROQ": "groq",
+ "COHERE": "cohere",
+ "GOOGLE": "gemini",
+ "OLLAMA": "ollama_chat",
+ "MISTRAL": "mistral",
+ "AZURE_OPENAI": "azure",
+ "OPENROUTER": "openrouter",
+ "XAI": "xai",
+ "BEDROCK": "bedrock",
+ "VERTEX_AI": "vertex_ai",
+ "TOGETHER_AI": "together_ai",
+ "FIREWORKS_AI": "fireworks_ai",
+ "DEEPSEEK": "openai",
+ "ALIBABA_QWEN": "openai",
+ "MOONSHOT": "openai",
+ "ZHIPU": "openai",
+ "GITHUB_MODELS": "github",
+ "REPLICATE": "replicate",
+ "PERPLEXITY": "perplexity",
+ "ANYSCALE": "anyscale",
+ "DEEPINFRA": "deepinfra",
+ "CEREBRAS": "cerebras",
+ "SAMBANOVA": "sambanova",
+ "AI21": "ai21",
+ "CLOUDFLARE": "cloudflare",
+ "DATABRICKS": "databricks",
+ "COMETAPI": "cometapi",
+ "HUGGINGFACE": "huggingface",
+ "MINIMAX": "openai",
+ "CUSTOM": "custom",
+}
+
+
+def _candidate_model_strings(
+ *,
+ provider: str | None,
+ model_name: str | None,
+ base_model: str | None,
+ custom_provider: str | None,
+) -> list[tuple[str, str | None]]:
+ """Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates.
+
+ LiteLLM's capability lookup is keyed by ``model`` + (optional)
+ ``custom_llm_provider``. Different config sources give us different
+ levels of detail, so we try the most-specific keys first and fall back
+ to bare model names so unannotated entries (e.g. an Azure deployment
+ pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the
+ map. Order matters — the first lookup that returns a definitive answer
+ wins for both helpers.
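+
+ Illustrative doctest (hypothetical Azure deployment whose deployment id
+ differs from its canonical base model):
+
+ >>> _candidate_model_strings(provider="AZURE_OPENAI", model_name="my-deploy",
+ ...                          base_model="gpt-4o", custom_provider=None)
+ [('azure/gpt-4o', 'azure'), ('gpt-4o', 'azure'), ('azure/my-deploy', 'azure'), ('my-deploy', 'azure'), ('my-deploy', None)]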
+ """
+ candidates: list[tuple[str, str | None]] = []
+ seen: set[tuple[str, str | None]] = set()
+
+ def _add(model: str | None, llm_provider: str | None) -> None:
+ if not model:
+ return
+ key = (model, llm_provider)
+ if key in seen:
+ return
+ seen.add(key)
+ candidates.append(key)
+
+ provider_prefix: str | None = None
+ if provider:
+ provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower())
+ if custom_provider:
+ # ``custom_provider`` overrides everything for CUSTOM/proxy setups.
+ provider_prefix = custom_provider
+
+ primary_model = base_model or model_name
+ bare_model = model_name
+
+ # Most-specific first: provider-prefixed identifier with explicit
+ # custom_llm_provider so LiteLLM won't have to guess the provider via
+ # ``get_llm_provider``.
+ if primary_model and provider_prefix:
+ # e.g. "azure/gpt-5.4" + custom_llm_provider="azure"
+ if "/" in primary_model:
+ _add(primary_model, provider_prefix)
+ else:
+ _add(f"{provider_prefix}/{primary_model}", provider_prefix)
+
+ # Bare base_model (or model_name) with provider hint — handles entries
+ # the upstream map keys without a provider prefix (most ``gpt-*`` and
+ # ``claude-*`` entries do this).
+ if primary_model:
+ _add(primary_model, provider_prefix)
+
+ # Fallback to model_name when base_model differs (e.g. an Azure
+ # deployment whose model_name is the deployment id but base_model is the
+ # canonical OpenAI sku).
+ if bare_model and bare_model != primary_model:
+ if provider_prefix and "/" not in bare_model:
+ _add(f"{provider_prefix}/{bare_model}", provider_prefix)
+ _add(bare_model, provider_prefix)
+ _add(bare_model, None)
+
+ return candidates
+
+
+def derive_supports_image_input(
+ *,
+ provider: str | None = None,
+ model_name: str | None = None,
+ base_model: str | None = None,
+ custom_provider: str | None = None,
+ openrouter_input_modalities: Iterable[str] | None = None,
+) -> bool:
+ """Best-effort capability flag for the new-chat selector and catalog.
+
+ Resolution order (first definitive answer wins):
+
+ 1. ``openrouter_input_modalities`` (when provided as a non-empty
+ iterable). OpenRouter exposes ``architecture.input_modalities`` per
+ model and that's the authoritative source for OR dynamic configs.
+ 2. ``litellm.supports_vision`` against each candidate identifier from
+ :func:`_candidate_model_strings`. Returns True as soon as any
+ candidate confirms vision support.
+ 3. Default ``True`` — the conservative-allow stance. An unknown /
+ newly-added / third-party-hosted model is *not* pre-judged. The
+ streaming safety net (:func:`is_known_text_only_chat_model`) is the
+ only place a False ever blocks; everywhere else, a False here would
+ just hide a usable model from the user.
+
+ Returns:
+ True if the model can plausibly accept image input, False only when
+ OpenRouter explicitly says it can't.
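+
+ Illustrative doctest (OR-modalities branch only; the LiteLLM branch
+ depends on the installed model map, so it isn't asserted here):
+
+ >>> derive_supports_image_input(openrouter_input_modalities=["text"])
+ False
+ >>> derive_supports_image_input(openrouter_input_modalities=["text", "image"])
+ True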
+ """
+ if openrouter_input_modalities is not None:
+ modalities = list(openrouter_input_modalities)
+ if modalities:
+ return "image" in modalities
+ # Empty list explicitly published by OR — treat as "no image".
+ return False
+
+ for model_string, custom_llm_provider in _candidate_model_strings(
+ provider=provider,
+ model_name=model_name,
+ base_model=base_model,
+ custom_provider=custom_provider,
+ ):
+ try:
+ if litellm.supports_vision(
+ model=model_string, custom_llm_provider=custom_llm_provider
+ ):
+ return True
+ except Exception as exc:
+ logger.debug(
+ "litellm.supports_vision raised for model=%s provider=%s: %s",
+ model_string,
+ custom_llm_provider,
+ exc,
+ )
+ continue
+
+ # Default-allow. ``is_known_text_only_chat_model`` is the strict gate.
+ return True
+
+
+def is_known_text_only_chat_model(
+ *,
+ provider: str | None = None,
+ model_name: str | None = None,
+ base_model: str | None = None,
+ custom_provider: str | None = None,
+) -> bool:
+ """Strict opt-out probe for the streaming-task safety net.
+
+ Returns True only when LiteLLM's model map *explicitly* sets
+ ``supports_vision=False`` for at least one candidate identifier. Missing
+ key, lookup exception, or ``supports_vision=True`` all return False so
+ the streaming task lets the request through. This is the inverse-default
+ of :func:`derive_supports_image_input`.
+
+ Why two functions
+ -----------------
+ The selector wants "show me everything that's plausibly capable" —
+ default-allow. The safety net wants "block only when I'm certain it
+ can't" — default-pass. Mixing the two intents in a single function
+ leads to the regression we're fixing here.
+ """
+ for model_string, custom_llm_provider in _candidate_model_strings(
+ provider=provider,
+ model_name=model_name,
+ base_model=base_model,
+ custom_provider=custom_provider,
+ ):
+ try:
+ info = litellm.get_model_info(
+ model=model_string, custom_llm_provider=custom_llm_provider
+ )
+ except Exception as exc:
+ logger.debug(
+ "litellm.get_model_info raised for model=%s provider=%s: %s",
+ model_string,
+ custom_llm_provider,
+ exc,
+ )
+ continue
+
+ # ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision``
+ # may be missing, None, True, or False. We only fire on explicit
+ # False — None / missing / True all mean "don't block".
+ try:
+ value = info.get("supports_vision") # type: ignore[union-attr]
+ except AttributeError:
+ value = None
+ if value is False:
+ return True
+
+ return False
+
+
+__all__ = [
+ "derive_supports_image_input",
+ "is_known_text_only_chat_model",
+]
diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py
new file mode 100644
index 000000000..2fb37de21
--- /dev/null
+++ b/surfsense_backend/app/services/quality_score.py
@@ -0,0 +1,380 @@
+"""Pure-function quality scoring for Auto (Fastest) model selection.
+
+This module is import-free of any service / request-path dependencies. All
+numbers are computed once during the OpenRouter refresh tick (or YAML load)
+and cached on the cfg dict, so the chat hot path only does a precomputed
+sort and a SHA256 pick.
+
+Score components (0-100 scale, higher is better):
+
+* ``static_score_or`` - derived from the bulk ``/api/v1/models`` payload
+ (provider prestige + ``created`` recency + pricing band + context window
+ + capabilities + narrow tiny/legacy slug penalty).
+* ``static_score_yaml`` - same shape for hand-curated YAML configs, plus
+ an operator-trust bonus (the operator deliberately picked this model).
+* ``aggregate_health`` - run on per-model ``/api/v1/models/{id}/endpoints``
+ responses; returns ``(gated, score_or_none)``.
+
+The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in
+:mod:`app.services.openrouter_integration_service` because that's the only
+caller that sees both halves.
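+
+Illustrative blend (numbers hypothetical): a cfg whose static score is 70 and
+whose health score is 90 lands at ``0.5 * 70 + 0.5 * 90 = 80``.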
+"""
+
+from __future__ import annotations
+
+# ---------------------------------------------------------------------------
+# Tunables (constants, not flags)
+# ---------------------------------------------------------------------------
+
+# Top-K size for deterministic spread inside the locked tier.
+_QUALITY_TOP_K: int = 5
+
+# Hard health gate: any cfg whose best non-null uptime is below this %
+# is excluded from Auto-mode selection entirely.
+_HEALTH_GATE_UPTIME_PCT: float = 90.0
+
+# Health/static blend weight when a cfg has fresh /endpoints data.
+_HEALTH_BLEND_WEIGHT: float = 0.5
+
+# Static bonus applied to YAML cfgs because the operator hand-picked them.
+_OPERATOR_TRUST_BONUS: int = 20
+
+# /endpoints fan-out is bounded per refresh tick.
+_HEALTH_ENRICH_TOP_N_PREMIUM: int = 50
+_HEALTH_ENRICH_TOP_N_FREE: int = 30
+_HEALTH_ENRICH_CONCURRENCY: int = 15
+_HEALTH_FETCH_TIMEOUT_SEC: float = 5.0
+
+# If at least this fraction of /endpoints fetches fail in a refresh cycle,
+# fall back to the previous cycle's last-good cache instead of writing
+# partial / stale health values.
+_HEALTH_FAIL_RATIO_FALLBACK: float = 0.25
+
+# Narrow tiny/legacy slug penalties only. We deliberately do NOT penalise
+# ``-nano`` / ``-mini`` / ``-lite`` because modern frontier models ship with
+# those naming patterns (``gpt-5-mini``, ``gemini-2.5-flash-lite`` etc.) and
+# blanket-penalising them suppresses high-quality picks.
+_TINY_LEGACY_PENALTY_PATTERNS: tuple[str, ...] = (
+ "-1b-",
+ "-1.2b-",
+ "-1.5b-",
+ "-2b-",
+ "-3b-",
+ "gemma-3n",
+ "lfm-",
+ "-base",
+ "-distill",
+ ":nitro",
+ "-preview",
+)
+
+
+# ---------------------------------------------------------------------------
+# Provider prestige tables
+# ---------------------------------------------------------------------------
+
+# OpenRouter-side provider slug (the prefix before ``/`` in the model id).
+# Tiers are coarse: frontier labs > strong open / fast-moving labs >
+# specialist labs > everything else.
+PROVIDER_PRESTIGE_OR: dict[str, int] = {
+ # Frontier labs
+ "openai": 50,
+ "anthropic": 50,
+ "google": 50,
+ "x-ai": 50,
+ # Strong open / fast-moving labs
+ "deepseek": 38,
+ "qwen": 38,
+ "meta-llama": 38,
+ "mistralai": 38,
+ "cohere": 38,
+ "nvidia": 38,
+ "alibaba": 38,
+ # Specialist / regional / strong second-tier
+ "microsoft": 28,
+ "01-ai": 28,
+ "minimax": 28,
+ "moonshot": 28,
+ "z-ai": 28,
+ "nousresearch": 28,
+ "ai21": 28,
+ "perplexity": 28,
+ # Smaller / niche providers
+ "liquid": 18,
+ "cognitivecomputations": 18,
+ "venice": 18,
+ "inflection": 18,
+}
+
+# YAML provider field (the upstream API shape the operator selected).
+PROVIDER_PRESTIGE_YAML: dict[str, int] = {
+ "AZURE_OPENAI": 50,
+ "OPENAI": 50,
+ "ANTHROPIC": 50,
+ "GOOGLE": 50,
+ "VERTEX_AI": 50,
+ "GEMINI": 50,
+ "XAI": 50,
+ "MISTRAL": 38,
+ "DEEPSEEK": 38,
+ "COHERE": 38,
+ "GROQ": 30,
+ "TOGETHER_AI": 28,
+ "FIREWORKS_AI": 28,
+ "PERPLEXITY": 28,
+ "MINIMAX": 28,
+ "BEDROCK": 28,
+ "OPENROUTER": 25,
+ "OLLAMA": 12,
+ "CUSTOM": 12,
+}
+
+
+# ---------------------------------------------------------------------------
+# Pure scoring helpers
+# ---------------------------------------------------------------------------
+
+# Calibrated against the live /api/v1/models bulk dump. Frontier models
+# released in the last ~6 months (GPT-5 family, Claude 4.x, Gemini 2.5,
+# Grok 4) score in the 18-20 band; mid-2024 models in the 8-12 band;
+# anything older trails off.
+_RECENCY_BANDS_DAYS: tuple[tuple[int, int], ...] = (
+ (60, 20),
+ (180, 16),
+ (365, 12),
+ (540, 9),
+ (730, 6),
+ (1095, 3),
+)
+
+
+def created_recency_signal(created_ts: int | None, now_ts: int) -> int:
+ """Return 0-20 based on how recently the model was published.
+
+ Uses the OpenRouter ``created`` Unix timestamp (or any equivalent for
+ YAML cfgs). Models without a usable timestamp get 0 (we don't penalise,
+ we just don't reward).
+ """
+ if created_ts is None or created_ts <= 0 or now_ts <= 0:
+ return 0
+ age_days = max(0, (now_ts - int(created_ts)) // 86_400)
+ for cutoff, score in _RECENCY_BANDS_DAYS:
+ if age_days <= cutoff:
+ return score
+ return 0
+
+
+def pricing_band(
+ prompt: str | float | int | None,
+ completion: str | float | int | None,
+) -> int:
+ """Return 0-15 based on combined prompt+completion cost per 1M tokens.
+
+ Higher-priced models tend to be the larger / more capable ones. A free
+ model returns 0 (we use other signals to rank free-vs-free instead).
+ Uncoercible inputs are treated as 0 rather than raising.
+ """
+
+ def _to_float(value) -> float:
+ if value is None:
+ return 0.0
+ try:
+ return float(value)
+ except (TypeError, ValueError):
+ return 0.0
+
+ p = _to_float(prompt)
+ c = _to_float(completion)
+ total_per_million = (p + c) * 1_000_000
+
+ if total_per_million >= 20.0:
+ return 15
+ if total_per_million >= 5.0:
+ return 12
+ if total_per_million >= 1.0:
+ return 9
+ if total_per_million >= 0.3:
+ return 6
+ if total_per_million >= 0.05:
+ return 4
+ if total_per_million > 0.0:
+ return 2
+ return 0
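+
+# Example (prices hypothetical): prompt $3/M and completion $15/M arrive as
+# per-token strings ("0.000003", "0.000015"); (3 + 15) = $18 combined per 1M
+# tokens, which lands in the >= 5.0 band -> 12.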
+
+
+def context_signal(ctx: int | None) -> int:
+ """Return 0-10 based on the model's context window."""
+ if not ctx or ctx <= 0:
+ return 0
+ if ctx >= 1_000_000:
+ return 10
+ if ctx >= 400_000:
+ return 8
+ if ctx >= 200_000:
+ return 6
+ if ctx >= 128_000:
+ return 4
+ if ctx >= 100_000:
+ return 2
+ return 0
+
+
+def capabilities_signal(supported_parameters: list[str] | None) -> int:
+ """Return 0-5 for capabilities that matter for our agent flows."""
+ if not supported_parameters:
+ return 0
+ params = set(supported_parameters)
+ score = 0
+ if "tools" in params:
+ score += 2
+ if "structured_outputs" in params or "response_format" in params:
+ score += 2
+ if "reasoning" in params or "include_reasoning" in params:
+ score += 1
+ return min(score, 5)
+
+
+def slug_penalty(model_id: str) -> int:
+ """Return a non-positive number; matches the narrow tiny/legacy patterns."""
+ if not model_id:
+ return 0
+ needle = model_id.lower()
+ for pattern in _TINY_LEGACY_PENALTY_PATTERNS:
+ if pattern in needle:
+ return -10
+ return 0
+
+
+def _provider_prestige_or(model_id: str) -> int:
+ if "/" not in model_id:
+ return 0
+ slug = model_id.split("/", 1)[0].lower()
+ return PROVIDER_PRESTIGE_OR.get(slug, 15)
+
+
+def static_score_or(or_model: dict, *, now_ts: int) -> int:
+ """Score a raw OpenRouter ``/api/v1/models`` entry on a 0-100 scale."""
+ model_id = str(or_model.get("id", ""))
+ pricing = or_model.get("pricing") or {}
+
+ score = (
+ _provider_prestige_or(model_id)
+ + created_recency_signal(or_model.get("created"), now_ts)
+ + pricing_band(pricing.get("prompt"), pricing.get("completion"))
+ + context_signal(or_model.get("context_length"))
+ + capabilities_signal(or_model.get("supported_parameters"))
+ + slug_penalty(model_id)
+ )
+ return max(0, min(100, int(score)))
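+
+# Worked example (model and numbers hypothetical): "anthropic/sonnet-x"
+# released 90 days ago, $3/M + $15/M pricing, 200K context, tools +
+# structured outputs, no slug penalty: 50 + 16 + 12 + 6 + 4 + 0 = 88.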
+
+
+def static_score_yaml(cfg: dict) -> int:
+ """Score a YAML-curated cfg on a 0-100 scale.
+
+ Includes ``_OPERATOR_TRUST_BONUS`` because the operator deliberately
+ listed this model. Pricing / context fall through to lazy ``litellm``
+ lookups; failures are silent (we just lose those sub-points).
+ """
+ provider = str(cfg.get("provider", "")).upper()
+ base = PROVIDER_PRESTIGE_YAML.get(provider, 15)
+
+ model_name = cfg.get("model_name") or ""
+ litellm_params = cfg.get("litellm_params") or {}
+ lookup_name = (
+ litellm_params.get("base_model") or litellm_params.get("model") or model_name
+ )
+
+ ctx = 0
+ p_cost: float = 0.0
+ c_cost: float = 0.0
+ try:
+ from litellm import get_model_info # lazy: avoid cold-import cost
+
+ info = get_model_info(lookup_name) or {}
+ ctx = int(info.get("max_input_tokens") or info.get("max_tokens") or 0)
+ p_cost = float(info.get("input_cost_per_token") or 0.0)
+ c_cost = float(info.get("output_cost_per_token") or 0.0)
+ except Exception:
+ # Unknown to litellm — that's fine for prestige+operator-bonus weighting.
+ pass
+
+ score = (
+ base
+ + _OPERATOR_TRUST_BONUS
+ + pricing_band(p_cost, c_cost)
+ + context_signal(ctx)
+ + slug_penalty(str(model_name))
+ )
+ return max(0, min(100, int(score)))
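+
+# Worked example (cfg illustrative, assuming litellm resolves its pricing and
+# context): provider=ANTHROPIC -> base 50, plus the operator-trust bonus 20,
+# pricing band 12, context signal 6, no slug penalty: 50 + 20 + 12 + 6 = 88.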
+
+
+# ---------------------------------------------------------------------------
+# Health aggregation
+# ---------------------------------------------------------------------------
+
+
+def _coerce_pct(value) -> float | None:
+ try:
+ if value is None:
+ return None
+ f = float(value)
+ except (TypeError, ValueError):
+ return None
+ if f < 0:
+ return None
+ # OpenRouter reports uptime as a 0-1 fraction; some endpoints surface it
+ # as a 0-100 percentage. Normalise.
+ return f * 100.0 if f <= 1.0 else f
+
+
+def _best_uptime(endpoints: list[dict]) -> tuple[float | None, str | None]:
+ """Pick the best (highest) non-null uptime across all endpoints.
+
+ Window preference: ``uptime_last_30m`` > ``uptime_last_1d`` >
+ ``uptime_last_5m``. Returns ``(uptime_pct, window_used)``.
+ """
+ for window in ("uptime_last_30m", "uptime_last_1d", "uptime_last_5m"):
+ values = [_coerce_pct(ep.get(window)) for ep in endpoints]
+ values = [v for v in values if v is not None]
+ if values:
+ return max(values), window
+ return None, None
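+
+# Example (shape illustrative): [{"uptime_last_30m": 0.98}, {"uptime_last_1d": 99.0}]
+# -> (98.0, "uptime_last_30m"); the 30m window wins and the 0-1 fraction is
+# normalised to a percentage by ``_coerce_pct``.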
+
+
+def aggregate_health(endpoints: list[dict]) -> tuple[bool, float | None]:
+ """Aggregate a model's per-endpoint health into ``(gated, score_or_none)``.
+
+ Hard gate (returns ``(True, None)``):
+ * ``endpoints`` empty,
+ * no endpoint reports ``status == 0`` (OK), or
+ * best non-null uptime below ``_HEALTH_GATE_UPTIME_PCT``.
+
+ On a pass, returns a 0-100 health score blending uptime, status, and a
+ freshness-weighted recent uptime sample.
+ """
+ if not endpoints:
+ return True, None
+
+ any_ok = any(int(ep.get("status", 1)) == 0 for ep in endpoints)
+ if not any_ok:
+ return True, None
+
+ best_uptime, _ = _best_uptime(endpoints)
+ if best_uptime is None or best_uptime < _HEALTH_GATE_UPTIME_PCT:
+ return True, None
+
+ # Freshness term: prefer 5m, fall through to 30m / 1d if 5m is missing.
+ freshness = None
+ for window in ("uptime_last_5m", "uptime_last_30m", "uptime_last_1d"):
+ values = [_coerce_pct(ep.get(window)) for ep in endpoints]
+ values = [v for v in values if v is not None]
+ if values:
+ freshness = max(values)
+ break
+
+ uptime_term = best_uptime
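+    # ``any_ok`` is guaranteed True past the gate above, so this term is a
+    # constant 100 today; kept explicit so the blend survives gate changes.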
+ status_term = 100.0 if any_ok else 0.0
+ freshness_term = freshness if freshness is not None else best_uptime
+
+ score = 0.50 * uptime_term + 0.30 * status_term + 0.20 * freshness_term
+ return False, max(0.0, min(100.0, score))
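+
+
+# Example (uptimes hypothetical): a single OK endpoint with
+# uptime_last_30m=97.5 and uptime_last_5m=99.0 passes the 90% gate and
+# scores 0.50 * 97.5 + 0.30 * 100.0 + 0.20 * 99.0 = 98.55.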
diff --git a/surfsense_backend/app/services/quota_checked_vision_llm.py b/surfsense_backend/app/services/quota_checked_vision_llm.py
new file mode 100644
index 000000000..0040e5a5b
--- /dev/null
+++ b/surfsense_backend/app/services/quota_checked_vision_llm.py
@@ -0,0 +1,105 @@
+"""
+Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
+
+Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
+indexing pipeline (file processors, connector indexers, etl pipeline) can
+keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)``
+— without threading ``user_id`` through every parser. The wrapper looks like
+a chat model from the outside; on the inside it routes each call through
+``billable_call`` so the user's premium credit pool is reserved → finalized
+or released, and a ``TokenUsage`` audit row is written.
+
+Free configs are returned unwrapped from ``get_vision_llm`` (they do not
+need quota enforcement) so this class only ever wraps premium configs.
+
+Why a wrapper instead of plumbing ``user_id`` through every caller:
+
+* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
+  Dropbox, local-folder, file-processor, ETL pipeline), each calling
+ ``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
+ invasive, error-prone, and easy for a future indexer to forget.
+* Per the design (issue M), we always debit the *search-space owner*, not
+ the triggering user, so ``user_id`` is fully derivable from the search
+ space the caller is already operating on. The wrapper captures it once
+ at construction time.
+* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
+ call run this coroutine"; subclassing isn't safe across versions because
+ it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
+ Composition via attribute proxying (``__getattr__``) is robust to
+ upstream changes — every method other than ``ainvoke`` falls through to
+ the inner LLM unchanged.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+from uuid import UUID
+
+from app.services.billable_calls import QuotaInsufficientError, billable_call
+
+logger = logging.getLogger(__name__)
+
+
+class QuotaCheckedVisionLLM:
+ """Composition wrapper around a langchain chat model that enforces
+ premium credit quota on every ``ainvoke``.
+
+ Anything other than ``ainvoke`` is forwarded to the inner model so
+ ``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
+ still work — they simply bypass quota enforcement, which is fine
+ because the indexing pipeline only ever calls ``ainvoke`` today.
+ """
+
+ def __init__(
+ self,
+ inner_llm: Any,
+ *,
+ user_id: UUID,
+ search_space_id: int,
+ billing_tier: str,
+ base_model: str,
+ quota_reserve_tokens: int | None,
+ usage_type: str = "vision_extraction",
+ ) -> None:
+ self._inner = inner_llm
+ self._user_id = user_id
+ self._search_space_id = search_space_id
+ self._billing_tier = billing_tier
+ self._base_model = base_model
+ self._quota_reserve_tokens = quota_reserve_tokens
+ self._usage_type = usage_type
+
+ async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
+ """Proxied async invoke that runs the underlying call inside
+ ``billable_call``.
+
+ Raises:
+ QuotaInsufficientError: when the user has exhausted their
+ premium credit pool. Caller (``etl_pipeline_service._extract_image``)
+ catches this and falls back to the document parser.
+ """
+ async with billable_call(
+ user_id=self._user_id,
+ search_space_id=self._search_space_id,
+ billing_tier=self._billing_tier,
+ base_model=self._base_model,
+ quota_reserve_tokens=self._quota_reserve_tokens,
+ usage_type=self._usage_type,
+ call_details={"model": self._base_model},
+ ):
+ return await self._inner.ainvoke(input, *args, **kwargs)
+
+ def __getattr__(self, name: str) -> Any:
+ """Forward everything else (``invoke``, ``astream``, ``bind``,
+ ``with_structured_output``, …) to the inner model.
+
+ ``__getattr__`` is only consulted when the attribute is *not*
+ already found on the proxy, which is exactly the contract we
+ want — methods we override stay on the proxy, the rest fall
+ through.
+ """
+ return getattr(self._inner, name)
+
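+# Usage sketch (values illustrative; ``get_vision_llm`` wires the real ones
+# from the search-space owner's config):
+#
+#     llm = QuotaCheckedVisionLLM(
+#         inner_llm,
+#         user_id=owner_id,
+#         search_space_id=space_id,
+#         billing_tier="premium",
+#         base_model="gpt-4o",
+#         quota_reserve_tokens=4096,
+#     )
+#     result = await llm.ainvoke(messages)  # quota-enforced
+#     bound = llm.bind(temperature=0)       # falls through, no quota
+#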
+
+__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]
diff --git a/surfsense_backend/app/services/token_quota_service.py b/surfsense_backend/app/services/token_quota_service.py
index a3ec7aed0..310c3eb5e 100644
--- a/surfsense_backend/app/services/token_quota_service.py
+++ b/surfsense_backend/app/services/token_quota_service.py
@@ -22,6 +22,71 @@ from app.config import config
logger = logging.getLogger(__name__)
+# ---------------------------------------------------------------------------
+# Per-call reservation estimator (USD micro-units)
+# ---------------------------------------------------------------------------
+
+# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
+# request, and so models without registered pricing reserve at least
+# something while the call runs (debited 0 at finalize anyway when their
+# cost can't be resolved).
+_QUOTA_MIN_RESERVE_MICROS = 100
+
+
+def estimate_call_reserve_micros(
+ *,
+ base_model: str,
+ quota_reserve_tokens: int | None,
+) -> int:
+ """Return the number of micro-USD to reserve for one premium call.
+
+ Computes a worst-case upper bound from LiteLLM's per-token pricing
+ table:
+
+        reserve_usd ≈ reserve_tokens × (input_cost + output_cost)
+
+ so the math scales with model cost — Claude Opus + 4K reserve_tokens
+ naturally reserves ≈ $0.36, while a cheap model reserves only a few
+ cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]``
+ so a misconfigured "$1000/M" model can't lock the whole balance on
+ one call.
+
+ If ``litellm.get_model_info`` raises (model unknown) we fall back to
+ the floor — 100 micros / $0.0001 — which is enough to gate a sane
+ request without over-reserving for a model whose pricing the
+ operator hasn't declared yet.
+ """
+ reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
+ if reserve_tokens <= 0:
+ reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL
+
+ try:
+ from litellm import get_model_info
+
+ info = get_model_info(base_model) if base_model else {}
+ input_cost = float(info.get("input_cost_per_token") or 0.0)
+ output_cost = float(info.get("output_cost_per_token") or 0.0)
+ except Exception as exc:
+ logger.debug(
+ "[quota_reserve] cost lookup failed for base_model=%s: %s",
+ base_model,
+ exc,
+ )
+ input_cost = 0.0
+ output_cost = 0.0
+
+ if input_cost == 0.0 and output_cost == 0.0:
+ return _QUOTA_MIN_RESERVE_MICROS
+
+ reserve_usd = reserve_tokens * (input_cost + output_cost)
+ reserve_micros = round(reserve_usd * 1_000_000)
+ if reserve_micros < _QUOTA_MIN_RESERVE_MICROS:
+ reserve_micros = _QUOTA_MIN_RESERVE_MICROS
+ if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS:
+ reserve_micros = config.QUOTA_MAX_RESERVE_MICROS
+ return reserve_micros
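+
+
+# Worked example (prices illustrative): a model priced $15/M input and $75/M
+# output costs (0.000015 + 0.000075) USD per token; a 4_096-token reserve is
+# 4_096 * 0.00009 ≈ $0.369 -> 368_640 micros, inside the clamp window.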
+
+
class QuotaScope(StrEnum):
ANONYMOUS = "anonymous"
PREMIUM = "premium"
@@ -444,8 +509,16 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
- reserve_tokens: int,
+ reserve_micros: int,
) -> QuotaResult:
+ """Reserve ``reserve_micros`` (USD micro-units) from the user's
+ premium credit balance.
+
+ ``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
+ all in micro-USD on this code path; callers (chat stream,
+ token-status route, FE display) convert to dollars by dividing
+ by 1_000_000.
+ """
from app.db import User
user = (
@@ -465,11 +538,11 @@ class TokenQuotaService:
limit=0,
)
- limit = user.premium_tokens_limit
- used = user.premium_tokens_used
- reserved = user.premium_tokens_reserved
+ limit = user.premium_credit_micros_limit
+ used = user.premium_credit_micros_used
+ reserved = user.premium_credit_micros_reserved
- effective = used + reserved + reserve_tokens
+ effective = used + reserved + reserve_micros
if effective > limit:
remaining = max(0, limit - used - reserved)
await db_session.rollback()
@@ -482,10 +555,10 @@ class TokenQuotaService:
remaining=remaining,
)
- user.premium_tokens_reserved = reserved + reserve_tokens
+ user.premium_credit_micros_reserved = reserved + reserve_micros
await db_session.commit()
- new_reserved = reserved + reserve_tokens
+ new_reserved = reserved + reserve_micros
remaining = max(0, limit - used - new_reserved)
warning_threshold = int(limit * 0.8)
@@ -510,9 +583,12 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
- actual_tokens: int,
- reserved_tokens: int,
+ actual_micros: int,
+ reserved_micros: int,
) -> QuotaResult:
+ """Settle the reservation: release ``reserved_micros`` and debit
+ ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
+ """
from app.db import User
user = (
@@ -529,16 +605,18 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
- user.premium_tokens_reserved = max(
- 0, user.premium_tokens_reserved - reserved_tokens
+ user.premium_credit_micros_reserved = max(
+ 0, user.premium_credit_micros_reserved - reserved_micros
+ )
+ user.premium_credit_micros_used = (
+ user.premium_credit_micros_used + actual_micros
)
- user.premium_tokens_used = user.premium_tokens_used + actual_tokens
await db_session.commit()
- limit = user.premium_tokens_limit
- used = user.premium_tokens_used
- reserved = user.premium_tokens_reserved
+ limit = user.premium_credit_micros_limit
+ used = user.premium_credit_micros_used
+ reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)
@@ -562,8 +640,13 @@ class TokenQuotaService:
async def premium_release(
db_session: AsyncSession,
user_id: Any,
- reserved_tokens: int,
+ reserved_micros: int,
) -> None:
+ """Release ``reserved_micros`` previously held by ``premium_reserve``.
+
+ Used when a request fails before finalize (so the reservation
+ doesn't leak credit).
+ """
from app.db import User
user = (
@@ -576,8 +659,8 @@ class TokenQuotaService:
.scalar_one_or_none()
)
if user is not None:
- user.premium_tokens_reserved = max(
- 0, user.premium_tokens_reserved - reserved_tokens
+ user.premium_credit_micros_reserved = max(
+ 0, user.premium_credit_micros_reserved - reserved_micros
)
await db_session.commit()
@@ -598,9 +681,9 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
- limit = user.premium_tokens_limit
- used = user.premium_tokens_used
- reserved = user.premium_tokens_reserved
+ limit = user.premium_credit_micros_limit
+ used = user.premium_credit_micros_used
+ reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)
diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py
index 9aa8c6e70..9406d9be4 100644
--- a/surfsense_backend/app/services/token_tracking_service.py
+++ b/surfsense_backend/app/services/token_tracking_service.py
@@ -16,11 +16,14 @@ from __future__ import annotations
import dataclasses
import logging
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
+import litellm
from litellm.integrations.custom_logger import CustomLogger
from sqlalchemy.ext.asyncio import AsyncSession
@@ -35,6 +38,8 @@ class TokenCallRecord:
prompt_tokens: int
completion_tokens: int
total_tokens: int
+ cost_micros: int = 0
+ call_kind: str = "chat"
@dataclass
@@ -49,6 +54,8 @@ class TurnTokenAccumulator:
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
+ cost_micros: int = 0,
+ call_kind: str = "chat",
) -> None:
self.calls.append(
TokenCallRecord(
@@ -56,20 +63,28 @@ class TurnTokenAccumulator:
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
+ cost_micros=cost_micros,
+ call_kind=call_kind,
)
)
def per_message_summary(self) -> dict[str, dict[str, int]]:
- """Return token counts grouped by model name."""
+ """Return token counts (and cost) grouped by model name."""
by_model: dict[str, dict[str, int]] = {}
for c in self.calls:
entry = by_model.setdefault(
c.model,
- {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+ {
+ "prompt_tokens": 0,
+ "completion_tokens": 0,
+ "total_tokens": 0,
+ "cost_micros": 0,
+ },
)
entry["prompt_tokens"] += c.prompt_tokens
entry["completion_tokens"] += c.completion_tokens
entry["total_tokens"] += c.total_tokens
+ entry["cost_micros"] += c.cost_micros
return by_model
@property
@@ -84,6 +99,21 @@ class TurnTokenAccumulator:
def total_completion_tokens(self) -> int:
return sum(c.completion_tokens for c in self.calls)
+ @property
+ def total_cost_micros(self) -> int:
+ """Sum of per-call ``cost_micros`` across the entire turn.
+
+ Used by ``stream_new_chat`` to debit a premium turn's actual
+ provider cost (in micro-USD) from the user's premium credit
+ balance. ``cost_micros`` per call is captured by
+ ``TokenTrackingCallback.async_log_success_event`` from
+ ``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
+ with multiple fallback paths so OpenRouter dynamic models and
+        custom Azure deployments still bill correctly, provided our
+        ``pricing_registration`` ran at startup.
+ """
+ return sum(c.cost_micros for c in self.calls)
+
def serialized_calls(self) -> list[dict[str, Any]]:
return [dataclasses.asdict(c) for c in self.calls]
@@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
def start_turn() -> TurnTokenAccumulator:
- """Create a fresh accumulator for the current async context and return it."""
+ """Create a fresh accumulator for the current async context and return it.
+
+ NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
+ short-lived per-call billable wrappers (image generation REST endpoint,
+ vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
+ ContextVar reset token to restore the *previous* accumulator on exit and
+ avoids leaking call records across reservations (issue B).
+ """
acc = TurnTokenAccumulator()
_turn_accumulator.set(acc)
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
@@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
return _turn_accumulator.get()
+@asynccontextmanager
+async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
+ """Async context manager that scopes a fresh ``TurnTokenAccumulator``
+ for the duration of the ``async with`` block, then *resets* the
+ ContextVar to its previous value on exit.
+
+ This is the safe primitive for per-call billable operations
+ (image generation, vision LLM extraction, podcasts) that may run
+ inside an outer chat turn or be called sequentially from the same
+ background worker. Using ``ContextVar.set`` without ``reset`` (as
+ :func:`start_turn` does) would leak the inner accumulator into the
+ outer scope, causing the outer chat turn to debit cost twice.
+
+ Usage::
+
+ async with scoped_turn() as acc:
+ await llm.ainvoke(...)
+ # acc.total_cost_micros captures cost from the LiteLLM callback
+ # Outer accumulator (if any) is restored here.
+ """
+ acc = TurnTokenAccumulator()
+ token = _turn_accumulator.set(acc)
+ logger.debug(
+ "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
+ id(acc),
+ token,
+ )
+ try:
+ yield acc
+ finally:
+ _turn_accumulator.reset(token)
+ logger.debug(
+ "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
+ id(acc),
+ len(acc.calls),
+ acc.total_cost_micros,
+ )
+
+
+def _extract_cost_usd(
+ kwargs: dict[str, Any],
+ response_obj: Any,
+ model: str,
+ prompt_tokens: int,
+ completion_tokens: int,
+ is_image: bool = False,
+) -> float:
+ """Best-effort USD cost extraction for a single LLM/image call.
+
+ Tries four sources in priority order and returns the first that
+ yields a positive number; returns 0.0 if all four fail (the call
+ will then debit nothing from the user's balance — fail-safe).
+
+ Sources:
+ 1. ``kwargs["response_cost"]`` — LiteLLM's standard callback
+ field, populated for ``Router.acompletion`` since PR #12500.
+ 2. ``response_obj._hidden_params["response_cost"]`` — same value
+ exposed on the response itself.
+ 3. ``litellm.completion_cost(completion_response=response_obj)``
+ — recompute from the response and LiteLLM's pricing table.
+ 4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
+ — manual fallback for OpenRouter/custom-Azure models that
+ only resolve via aliases registered by
+ ``pricing_registration`` at startup. **Skipped for image
+ responses** — ``cost_per_token`` does not support ``ImageResponse``
+ and would raise; the cost map for image-gen lives in different
+ keys (``output_cost_per_image``) handled by ``completion_cost``.
+ """
+ cost = kwargs.get("response_cost")
+ if cost is not None:
+ try:
+ value = float(cost)
+ except (TypeError, ValueError):
+ value = 0.0
+ if value > 0:
+ return value
+
+ hidden = getattr(response_obj, "_hidden_params", None) or {}
+ if isinstance(hidden, dict):
+ cost = hidden.get("response_cost")
+ if cost is not None:
+ try:
+ value = float(cost)
+ except (TypeError, ValueError):
+ value = 0.0
+ if value > 0:
+ return value
+
+ try:
+ value = float(litellm.completion_cost(completion_response=response_obj))
+ if value > 0:
+ return value
+ except Exception as exc:
+ if is_image:
+ # Image-gen path: OpenRouter's image responses can omit
+ # ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
+ # then *raises* (no cost map for OpenRouter image models).
+ # Bail out with a warning rather than falling through to
+ # cost_per_token (which is also incompatible with ImageResponse).
+ logger.warning(
+ "[TokenTracking] completion_cost failed for image model=%s "
+ "(provider may have omitted usage.cost). Debiting 0. "
+ "Cause: %s",
+ model,
+ exc,
+ )
+ return 0.0
+ logger.debug(
+ "[TokenTracking] completion_cost failed for model=%s: %s", model, exc
+ )
+
+ if is_image:
+ # Never call cost_per_token for ImageResponse — keys mismatch and
+ # the function is documented chat-only.
+ return 0.0
+
+ if model and (prompt_tokens > 0 or completion_tokens > 0):
+ try:
+ prompt_cost, completion_cost = litellm.cost_per_token(
+ model=model,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ )
+ value = float(prompt_cost) + float(completion_cost)
+ if value > 0:
+ return value
+ except Exception as exc:
+ logger.debug(
+ "[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
+ )
+
+ return 0.0
+
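+# Example (values illustrative): a call where LiteLLM populated
+# kwargs["response_cost"]=0.00042 short-circuits at source 1 and becomes
+# 420 micros in the callback below (round(0.00042 * 1_000_000)).
+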
+
class TokenTrackingCallback(CustomLogger):
"""LiteLLM callback that captures token usage into the turn accumulator."""
@@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
)
return
+ # Detect image generation responses — they have a different usage
+ # shape (ImageUsage with input_tokens/output_tokens) and require a
+ # different cost-extraction path. We probe by class name to avoid a
+ # hard import dependency on litellm internals.
+ response_cls = type(response_obj).__name__
+ is_image = response_cls == "ImageResponse"
+
usage = getattr(response_obj, "usage", None)
if not usage:
logger.debug(
@@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
)
return
- prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
- completion_tokens = getattr(usage, "completion_tokens", 0) or 0
- total_tokens = getattr(usage, "total_tokens", 0) or 0
+ if is_image:
+ # ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
+ # (not prompt_tokens/completion_tokens). Several providers
+ # populate only one or neither (e.g. OpenRouter's gpt-image-1
+ # passes through `input_tokens` from the prompt but no
+            # completion); we fall back gracefully to 0.
+ prompt_tokens = getattr(usage, "input_tokens", 0) or 0
+ completion_tokens = getattr(usage, "output_tokens", 0) or 0
+ total_tokens = (
+ getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
+ )
+ call_kind = "image_generation"
+ else:
+ prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
+ completion_tokens = getattr(usage, "completion_tokens", 0) or 0
+ total_tokens = getattr(usage, "total_tokens", 0) or 0
+ call_kind = "chat"
model = kwargs.get("model", "unknown")
+ cost_usd = _extract_cost_usd(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ model=model,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ is_image=is_image,
+ )
+ cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0
+
+ if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
+ logger.warning(
+ "[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
+ "kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
+ "input_cost_per_token/output_cost_per_token (or rely on response_cost "
+ "for image generation).",
+ model,
+ prompt_tokens,
+ completion_tokens,
+ call_kind,
+ )
+
acc.add(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
+ cost_micros=cost_micros,
+ call_kind=call_kind,
)
logger.info(
- "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
+ "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
+ "cost=$%.6f (%d micros) (accumulator now has %d calls)",
model,
+ call_kind,
prompt_tokens,
completion_tokens,
total_tokens,
+ cost_usd,
+ cost_micros,
len(acc.calls),
)
@@ -168,6 +388,7 @@ async def record_token_usage(
prompt_tokens: int = 0,
completion_tokens: int = 0,
total_tokens: int = 0,
+ cost_micros: int = 0,
model_breakdown: dict[str, Any] | None = None,
call_details: dict[str, Any] | None = None,
thread_id: int | None = None,
@@ -185,6 +406,7 @@ async def record_token_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
+ cost_micros=cost_micros,
model_breakdown=model_breakdown,
call_details=call_details,
thread_id=thread_id,
@@ -194,11 +416,12 @@ async def record_token_usage(
)
session.add(record)
logger.debug(
- "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
+ "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
usage_type,
prompt_tokens,
completion_tokens,
total_tokens,
+ cost_micros,
)
return record
except Exception:
diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py
index 0d782ab2b..ed5de921c 100644
--- a/surfsense_backend/app/services/vision_llm_router_service.py
+++ b/surfsense_backend/app/services/vision_llm_router_service.py
@@ -3,6 +3,8 @@ from typing import Any
from litellm import Router
+from app.services.provider_api_base import resolve_api_base
+
logger = logging.getLogger(__name__)
VISION_AUTO_MODE_ID = 0
@@ -108,10 +110,11 @@ class VisionLLMRouterService:
if not config.get("model_name") or not config.get("api_key"):
return None
+ provider = config.get("provider", "").upper()
if config.get("custom_provider"):
- model_string = f"{config['custom_provider']}/{config['model_name']}"
+ provider_prefix = config["custom_provider"]
+ model_string = f"{provider_prefix}/{config['model_name']}"
else:
- provider = config.get("provider", "").upper()
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
@@ -120,8 +123,13 @@ class VisionLLMRouterService:
"api_key": config.get("api_key"),
}
- if config.get("api_base"):
- litellm_params["api_base"] = config["api_base"]
+ api_base = resolve_api_base(
+ provider=provider,
+ provider_prefix=provider_prefix,
+ config_api_base=config.get("api_base"),
+ )
+ if api_base:
+ litellm_params["api_base"] = api_base
if config.get("api_version"):
litellm_params["api_version"] = config["api_version"]
diff --git a/surfsense_backend/app/tasks/celery_tasks/__init__.py b/surfsense_backend/app/tasks/celery_tasks/__init__.py
index 5b1f2cd13..b23359f36 100644
--- a/surfsense_backend/app/tasks/celery_tasks/__init__.py
+++ b/surfsense_backend/app/tasks/celery_tasks/__init__.py
@@ -1,10 +1,25 @@
-"""Celery tasks package."""
+"""Celery tasks package.
+
+Also hosts the small helpers every async celery task should use to
+spin up its event loop. See :func:`run_async_celery_task` for the
+canonical pattern.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import logging
+from collections.abc import Awaitable, Callable
+from typing import TypeVar
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.config import config
+logger = logging.getLogger(__name__)
+
_celery_engine = None
_celery_session_maker = None
@@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker:
_celery_engine, expire_on_commit=False
)
return _celery_session_maker
+
+
+def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None:
+ """Drop the shared ``app.db.engine`` connection pool synchronously.
+
+ The shared engine (used by ``shielded_async_session`` and most
+ routes / services) is a module-level singleton with a real pool.
+ Each celery task creates a fresh ``asyncio`` event loop; asyncpg
+ connections cache a reference to whichever loop opened them. When
+ a subsequent task's loop pulls a stale connection from the pool,
+ SQLAlchemy's ``pool_pre_ping`` checkout crashes with::
+
+ AttributeError: 'NoneType' object has no attribute 'send'
+ File ".../asyncio/proactor_events.py", line 402, in _loop_writing
+ self._write_fut = self._loop._proactor.send(self._sock, data)
+
+ or hangs forever inside the asyncpg ``Connection._cancel`` cleanup
+ coroutine that can never run because its loop is gone.
+
+ Disposing the engine forces the pool to drop every cached
+ connection so the next checkout opens a fresh one on the current
+ loop. Safe to call from a task's finally block; failure is logged
+ but never propagated.
+ """
+ try:
+ from app.db import engine as shared_engine
+
+ loop.run_until_complete(shared_engine.dispose())
+ except Exception:
+ logger.warning("Shared DB engine dispose() failed", exc_info=True)
+
+
+T = TypeVar("T")
+
+
+def run_async_celery_task(coro_factory: Callable[[], Awaitable[T]]) -> T:
+ """Run an async coroutine inside a fresh event loop with proper
+ DB-engine cleanup.
+
+ This is the canonical entry point for every async celery task.
+ It performs three responsibilities that were previously copy-pasted
+ (incorrectly) across each task module:
+
+ 1. Create a fresh ``asyncio`` loop and install it on the current
+ thread (celery's ``--pool=solo`` runs every task on the main
+ thread, but other pool types don't).
+ 2. Dispose the shared ``app.db.engine`` BEFORE the task runs so
+ any stale connections left over from a previous task's loop
+ are dropped — defends against tasks that crashed without
+ cleaning up.
+ 3. Dispose the shared engine AFTER the task runs so the
+ connections we opened on this loop are released before the
+ loop closes (avoids ``coroutine 'Connection._cancel' was
+ never awaited`` warnings and the next-task hang).
+
+ Use as::
+
+ @celery_app.task(name="my_task", bind=True)
+ def my_task(self, *args):
+ return run_async_celery_task(lambda: _my_task_impl(*args))
+ """
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ # Defense-in-depth: prior task may have crashed before
+ # disposing. Idempotent — no-op if pool is already empty.
+ _dispose_shared_db_engine(loop)
+ return loop.run_until_complete(coro_factory())
+ finally:
+ # Drop any connections this task opened so they don't leak
+ # into the next task's loop.
+ _dispose_shared_db_engine(loop)
+ with contextlib.suppress(Exception):
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ with contextlib.suppress(Exception):
+ asyncio.set_event_loop(None)
+ loop.close()
+
+
+__all__ = [
+ "get_celery_session_maker",
+ "run_async_celery_task",
+]
diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py
index fe1ac19d3..08d96cfa0 100644
--- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py
@@ -4,7 +4,7 @@ import logging
import traceback
from app.celery_app import celery_app
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -49,22 +49,15 @@ def index_notion_pages_task(
end_date: str,
):
"""Celery task to index Notion pages."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- loop.run_until_complete(
- _index_notion_pages(
+ return run_async_celery_task(
+ lambda: _index_notion_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_notion_pages", connector_id)
raise
- finally:
- loop.close()
async def _index_notion_pages(
@@ -95,19 +88,11 @@ def index_github_repos_task(
end_date: str,
):
"""Celery task to index GitHub repositories."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_github_repos(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_github_repos(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_github_repos(
@@ -138,19 +123,11 @@ def index_confluence_pages_task(
end_date: str,
):
"""Celery task to index Confluence pages."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_confluence_pages(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_confluence_pages(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_confluence_pages(
@@ -181,22 +158,15 @@ def index_google_calendar_events_task(
end_date: str,
):
"""Celery task to index Google Calendar events."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- loop.run_until_complete(
- _index_google_calendar_events(
+ return run_async_celery_task(
+ lambda: _index_google_calendar_events(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
raise
- finally:
- loop.close()
async def _index_google_calendar_events(
@@ -227,19 +197,11 @@ def index_google_gmail_messages_task(
end_date: str,
):
"""Celery task to index Google Gmail messages."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_google_gmail_messages(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_google_gmail_messages(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_google_gmail_messages(
@@ -269,22 +231,14 @@ def index_google_drive_files_task(
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
):
"""Celery task to index Google Drive folders and files."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_google_drive_files(
- connector_id,
- search_space_id,
- user_id,
- items_dict,
- )
+ return run_async_celery_task(
+ lambda: _index_google_drive_files(
+ connector_id,
+ search_space_id,
+ user_id,
+ items_dict,
)
- finally:
- loop.close()
+ )
async def _index_google_drive_files(
@@ -317,22 +271,14 @@ def index_onedrive_files_task(
items_dict: dict,
):
"""Celery task to index OneDrive folders and files."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_onedrive_files(
- connector_id,
- search_space_id,
- user_id,
- items_dict,
- )
+ return run_async_celery_task(
+ lambda: _index_onedrive_files(
+ connector_id,
+ search_space_id,
+ user_id,
+ items_dict,
)
- finally:
- loop.close()
+ )
async def _index_onedrive_files(
@@ -365,22 +311,14 @@ def index_dropbox_files_task(
items_dict: dict,
):
"""Celery task to index Dropbox folders and files."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_dropbox_files(
- connector_id,
- search_space_id,
- user_id,
- items_dict,
- )
+ return run_async_celery_task(
+ lambda: _index_dropbox_files(
+ connector_id,
+ search_space_id,
+ user_id,
+ items_dict,
)
- finally:
- loop.close()
+ )
async def _index_dropbox_files(
@@ -414,19 +352,11 @@ def index_elasticsearch_documents_task(
end_date: str,
):
"""Celery task to index Elasticsearch documents."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_elasticsearch_documents(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_elasticsearch_documents(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_elasticsearch_documents(
@@ -457,22 +387,15 @@ def index_crawled_urls_task(
end_date: str,
):
"""Celery task to index Web page Urls."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- loop.run_until_complete(
- _index_crawled_urls(
+ return run_async_celery_task(
+ lambda: _index_crawled_urls(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
raise
- finally:
- loop.close()
async def _index_crawled_urls(
@@ -503,19 +426,11 @@ def index_bookstack_pages_task(
end_date: str,
):
"""Celery task to index BookStack pages."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_bookstack_pages(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_bookstack_pages(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_bookstack_pages(
@@ -546,19 +461,11 @@ def index_composio_connector_task(
end_date: str | None,
):
"""Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio)."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_composio_connector(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_composio_connector(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_composio_connector(
diff --git a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py
index c2dbe7700..5d6bde6c1 100644
--- a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py
@@ -11,7 +11,7 @@ from app.db import Document
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str):
document_id: ID of document to reindex
user_id: ID of user who edited the document
"""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(_reindex_document(document_id, user_id))
- finally:
- loop.close()
+ return run_async_celery_task(lambda: _reindex_document(document_id, user_id))
async def _reindex_document(document_id: int, user_id: str):
diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py
index 9d12f91f6..c78e376bd 100644
--- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py
@@ -11,7 +11,7 @@ from app.celery_app import celery_app
from app.config import config
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from app.tasks.connector_indexers.local_folder_indexer import (
index_local_folder,
index_uploaded_files,
@@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int):
)
def delete_document_task(self, document_id: int):
"""Celery task to delete a document and its chunks in batches."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(_delete_document_background(document_id))
- finally:
- loop.close()
+ return run_async_celery_task(lambda: _delete_document_background(document_id))
async def _delete_document_background(document_id: int) -> None:
@@ -153,14 +148,9 @@ def delete_folder_documents_task(
folder_subtree_ids: list[int] | None = None,
):
"""Celery task to delete documents first, then the folder rows."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- _delete_folder_documents(document_ids, folder_subtree_ids)
- )
- finally:
- loop.close()
+ return run_async_celery_task(
+ lambda: _delete_folder_documents(document_ids, folder_subtree_ids)
+ )
async def _delete_folder_documents(
@@ -209,12 +199,9 @@ async def _delete_folder_documents(
)
def delete_search_space_task(self, search_space_id: int):
"""Celery task to delete a search space and heavy child rows in batches."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(_delete_search_space_background(search_space_id))
- finally:
- loop.close()
+ return run_async_celery_task(
+ lambda: _delete_search_space_background(search_space_id)
+ )
async def _delete_search_space_background(search_space_id: int) -> None:
@@ -269,18 +256,11 @@ def process_extension_document_task(
search_space_id: ID of the search space
user_id: ID of the user
"""
- # Create a new event loop for this task
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _process_extension_document(
- individual_document_dict, search_space_id, user_id
- )
+ return run_async_celery_task(
+ lambda: _process_extension_document(
+ individual_document_dict, search_space_id, user_id
)
- finally:
- loop.close()
+ )
async def _process_extension_document(
@@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st
search_space_id: ID of the search space
user_id: ID of the user
"""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id))
- finally:
- loop.close()
+ return run_async_celery_task(
+ lambda: _process_youtube_video(url, search_space_id, user_id)
+ )
async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
@@ -573,12 +549,9 @@ def process_file_upload_task(
except Exception as e:
logger.warning(f"[process_file_upload] Could not get file size: {e}")
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- loop.run_until_complete(
- _process_file_upload(file_path, filename, search_space_id, user_id)
+ run_async_celery_task(
+ lambda: _process_file_upload(file_path, filename, search_space_id, user_id)
)
logger.info(
f"[process_file_upload] Task completed successfully for: {filename}"
@@ -589,8 +562,6 @@ def process_file_upload_task(
f"Traceback:\n{traceback.format_exc()}"
)
raise
- finally:
- loop.close()
async def _process_file_upload(
@@ -811,25 +782,17 @@ def process_file_upload_with_document_task(
"File may have been removed before syncing could start."
)
# Mark document as failed since file is missing
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- _mark_document_failed(
- document_id,
- "File not found. Please re-upload the file.",
- )
+ run_async_celery_task(
+ lambda: _mark_document_failed(
+ document_id,
+ "File not found. Please re-upload the file.",
)
- finally:
- loop.close()
+ )
return
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- loop.run_until_complete(
- _process_file_with_document(
+ run_async_celery_task(
+ lambda: _process_file_with_document(
document_id,
temp_path,
filename,
@@ -849,8 +812,6 @@ def process_file_upload_with_document_task(
f"Traceback:\n{traceback.format_exc()}"
)
raise
- finally:
- loop.close()
async def _mark_document_failed(document_id: int, reason: str):
@@ -1119,22 +1080,16 @@ def process_circleback_meeting_task(
search_space_id: ID of the search space
connector_id: ID of the Circleback connector (for deletion support)
"""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _process_circleback_meeting(
- meeting_id,
- meeting_name,
- markdown_content,
- metadata,
- search_space_id,
- connector_id,
- )
+ return run_async_celery_task(
+ lambda: _process_circleback_meeting(
+ meeting_id,
+ meeting_name,
+ markdown_content,
+ metadata,
+ search_space_id,
+ connector_id,
)
- finally:
- loop.close()
+ )
async def _process_circleback_meeting(
@@ -1291,25 +1246,19 @@ def index_local_folder_task(
target_file_paths: list[str] | None = None,
):
"""Celery task to index a local folder. Config is passed directly — no connector row."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_local_folder_async(
- search_space_id=search_space_id,
- user_id=user_id,
- folder_path=folder_path,
- folder_name=folder_name,
- exclude_patterns=exclude_patterns,
- file_extensions=file_extensions,
- root_folder_id=root_folder_id,
- enable_summary=enable_summary,
- target_file_paths=target_file_paths,
- )
+ return run_async_celery_task(
+ lambda: _index_local_folder_async(
+ search_space_id=search_space_id,
+ user_id=user_id,
+ folder_path=folder_path,
+ folder_name=folder_name,
+ exclude_patterns=exclude_patterns,
+ file_extensions=file_extensions,
+ root_folder_id=root_folder_id,
+ enable_summary=enable_summary,
+ target_file_paths=target_file_paths,
)
- finally:
- loop.close()
+ )
async def _index_local_folder_async(
@@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task(
processing_mode: str = "basic",
):
"""Celery task to index files uploaded from the desktop app."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- _index_uploaded_folder_files_async(
- search_space_id=search_space_id,
- user_id=user_id,
- folder_name=folder_name,
- root_folder_id=root_folder_id,
- enable_summary=enable_summary,
- file_mappings=file_mappings,
- use_vision_llm=use_vision_llm,
- processing_mode=processing_mode,
- )
+ return run_async_celery_task(
+ lambda: _index_uploaded_folder_files_async(
+ search_space_id=search_space_id,
+ user_id=user_id,
+ folder_name=folder_name,
+ root_folder_id=root_folder_id,
+ enable_summary=enable_summary,
+ file_mappings=file_mappings,
+ use_vision_llm=use_vision_llm,
+ processing_mode=processing_mode,
)
- finally:
- loop.close()
+ )
async def _index_uploaded_folder_files_async(
@@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str:
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
"""Full AI sort for all documents in a search space."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
- finally:
- loop.close()
+ return run_async_celery_task(
+ lambda: _ai_sort_search_space_async(search_space_id, user_id)
+ )
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
@@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
)
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
"""Incremental AI sort for a single document after indexing."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- _ai_sort_document_async(search_space_id, user_id, document_id)
- )
- finally:
- loop.close()
+ return run_async_celery_task(
+ lambda: _ai_sort_document_async(search_space_id, user_id, document_id)
+ )
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):
diff --git a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py
index 98b107af3..c6c8666f5 100644
--- a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py
@@ -2,14 +2,13 @@
from __future__ import annotations
-import asyncio
import logging
from app.celery_app import celery_app
from app.db import SearchSourceConnector
from app.schemas.obsidian_plugin import NotePayload
from app.services.obsidian_plugin_indexer import upsert_note
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -22,18 +21,13 @@ def index_obsidian_attachment_task(
user_id: str,
) -> None:
"""Process one Obsidian non-markdown attachment asynchronously."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- _index_obsidian_attachment(
- connector_id=connector_id,
- payload_data=payload_data,
- user_id=user_id,
- )
+ return run_async_celery_task(
+ lambda: _index_obsidian_attachment(
+ connector_id=connector_id,
+ payload_data=payload_data,
+ user_id=user_id,
)
- finally:
- loop.close()
+ )
async def _index_obsidian_attachment(
diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
index 953011ecf..8b311576e 100644
--- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
@@ -3,14 +3,22 @@
import asyncio
import logging
import sys
+from contextlib import asynccontextmanager
from sqlalchemy import select
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
+from app.config import config as app_config
from app.db import Podcast, PodcastStatus
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.services.billable_calls import (
+ BillingSettlementError,
+ QuotaInsufficientError,
+ _resolve_agent_billing_for_search_space,
+ billable_call,
+)
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -28,6 +36,13 @@ if sys.platform.startswith("win"):
# =============================================================================
+@asynccontextmanager
+async def _celery_billable_session():
+ """Session factory used by billable_call inside the Celery worker loop."""
+ async with get_celery_session_maker()() as session:
+ yield session
+
+
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
self,
@@ -40,27 +55,22 @@ def generate_content_podcast_task(
Celery task to generate podcast from source content.
Updates existing podcast record created by the tool.
"""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- result = loop.run_until_complete(
- _generate_content_podcast(
+ return run_async_celery_task(
+ lambda: _generate_content_podcast(
podcast_id,
source_content,
search_space_id,
user_prompt,
)
)
- loop.run_until_complete(loop.shutdown_asyncgens())
- return result
except Exception as e:
logger.error(f"Error generating content podcast: {e!s}")
- loop.run_until_complete(_mark_podcast_failed(podcast_id))
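+ # Mark FAILED via a fresh run_async_celery_task call (the loop from
+ # the failed attempt is gone) and swallow secondary failures, same
+ # as the video-presentation task.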
+ try:
+ run_async_celery_task(lambda: _mark_podcast_failed(podcast_id))
+ except Exception:
+ logger.exception("Failed to mark podcast %s as failed", podcast_id)
return {"status": "failed", "podcast_id": podcast_id}
- finally:
- asyncio.set_event_loop(None)
- loop.close()
async def _mark_podcast_failed(podcast_id: int) -> None:
@@ -96,6 +106,31 @@ async def _generate_content_podcast(
podcast.status = PodcastStatus.GENERATING
await session.commit()
+ try:
+ (
+ owner_user_id,
+ billing_tier,
+ base_model,
+ ) = await _resolve_agent_billing_for_search_space(
+ session,
+ search_space_id,
+ thread_id=podcast.thread_id,
+ )
+ except ValueError as resolve_err:
+ logger.error(
+ "Podcast %s: cannot resolve billing for search_space=%s: %s",
+ podcast.id,
+ search_space_id,
+ resolve_err,
+ )
+ podcast.status = PodcastStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "podcast_id": podcast.id,
+ "reason": "billing_resolution_failed",
+ }
+
graph_config = {
"configurable": {
"podcast_title": podcast.title,
@@ -109,9 +144,52 @@ async def _generate_content_podcast(
db_session=session,
)
- graph_result = await podcaster_graph.ainvoke(
- initial_state, config=graph_config
- )
+ try:
+ async with billable_call(
+ user_id=owner_user_id,
+ search_space_id=search_space_id,
+ billing_tier=billing_tier,
+ base_model=base_model,
+ quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
+ usage_type="podcast_generation",
+ call_details={
+ "podcast_id": podcast.id,
+ "title": podcast.title,
+ "thread_id": podcast.thread_id,
+ },
+ billable_session_factory=_celery_billable_session,
+ ):
+ graph_result = await podcaster_graph.ainvoke(
+ initial_state, config=graph_config
+ )
+ except QuotaInsufficientError as exc:
+ logger.info(
+ "Podcast %s denied: out of premium credits "
+ "(used=%d/%d remaining=%d)",
+ podcast.id,
+ exc.used_micros,
+ exc.limit_micros,
+ exc.remaining_micros,
+ )
+ podcast.status = PodcastStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "podcast_id": podcast.id,
+ "reason": "premium_quota_exhausted",
+ }
+ except BillingSettlementError:
+ logger.exception(
+ "Podcast %s: premium billing settlement failed",
+ podcast.id,
+ )
+ podcast.status = PodcastStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "podcast_id": podcast.id,
+ "reason": "billing_settlement_failed",
+ }
podcast_transcript = graph_result.get("podcast_transcript", [])
file_path = graph_result.get("final_podcast_file_path", "")
@@ -133,7 +211,14 @@ async def _generate_content_podcast(
podcast.podcast_transcript = serializable_transcript
podcast.file_location = file_path
podcast.status = PodcastStatus.READY
+ logger.info(
+ "Podcast %s: committing READY transcript_entries=%d file=%s",
+ podcast.id,
+ len(serializable_transcript),
+ file_path,
+ )
await session.commit()
+ logger.info("Podcast %s: READY commit complete", podcast.id)
logger.info(f"Successfully generated podcast: {podcast.id}")
diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py
index 373f04b48..e41251407 100644
--- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py
+++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py
@@ -7,7 +7,7 @@ from sqlalchemy.future import select
from app.celery_app import celery_app
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from app.utils.indexing_locks import is_connector_indexing_locked
logger = logging.getLogger(__name__)
@@ -20,15 +20,7 @@ def check_periodic_schedules_task():
This task runs every minute and triggers indexing for any connector
whose next_scheduled_at time has passed.
"""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(_check_and_trigger_schedules())
- finally:
- loop.close()
+ return run_async_celery_task(_check_and_trigger_schedules)
async def _check_and_trigger_schedules():
diff --git a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py
index e05ae9435..d51c85dee 100644
--- a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py
+++ b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py
@@ -34,7 +34,7 @@ from sqlalchemy.future import select
from app.celery_app import celery_app
from app.config import config
from app.db import Document, DocumentStatus, Notification
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task():
Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task.
Also marks associated pending/processing documents as failed.
"""
- import asyncio
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
+ async def _both() -> None:
+ await _cleanup_stale_notifications()
+ await _cleanup_stale_document_processing_notifications()
- try:
- loop.run_until_complete(_cleanup_stale_notifications())
- loop.run_until_complete(_cleanup_stale_document_processing_notifications())
- finally:
- loop.close()
+ return run_async_celery_task(_both)
async def _cleanup_stale_notifications():
diff --git a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py
index 3aee1a360..ace6ef7ca 100644
--- a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py
+++ b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py
@@ -2,7 +2,6 @@
from __future__ import annotations
-import asyncio
import logging
from datetime import UTC, datetime, timedelta
@@ -18,7 +17,7 @@ from app.db import (
PremiumTokenPurchaseStatus,
)
from app.routes import stripe_routes
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None:
@celery_app.task(name="reconcile_pending_stripe_page_purchases")
def reconcile_pending_stripe_page_purchases_task():
"""Recover paid purchases that were left pending due to missed webhook handling."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(_reconcile_pending_page_purchases())
- finally:
- loop.close()
+ return run_async_celery_task(_reconcile_pending_page_purchases)
async def _reconcile_pending_page_purchases() -> None:
@@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None:
@celery_app.task(name="reconcile_pending_stripe_token_purchases")
def reconcile_pending_stripe_token_purchases_task():
"""Recover paid token purchases that were left pending due to missed webhook handling."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(_reconcile_pending_token_purchases())
- finally:
- loop.close()
+ return run_async_celery_task(_reconcile_pending_token_purchases)
async def _reconcile_pending_token_purchases() -> None:
diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
index 7880b385f..08f22140c 100644
--- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
@@ -3,14 +3,22 @@
import asyncio
import logging
import sys
+from contextlib import asynccontextmanager
from sqlalchemy import select
from app.agents.video_presentation.graph import graph as video_presentation_graph
from app.agents.video_presentation.state import State as VideoPresentationState
from app.celery_app import celery_app
+from app.config import config as app_config
from app.db import VideoPresentation, VideoPresentationStatus
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.services.billable_calls import (
+ BillingSettlementError,
+ QuotaInsufficientError,
+ _resolve_agent_billing_for_search_space,
+ billable_call,
+)
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -23,6 +31,13 @@ if sys.platform.startswith("win"):
)
+@asynccontextmanager
+async def _celery_billable_session():
+ """Session factory used by billable_call inside the Celery worker loop."""
+ async with get_celery_session_maker()() as session:
+ yield session
+
+
@celery_app.task(name="generate_video_presentation", bind=True)
def generate_video_presentation_task(
self,
@@ -35,27 +50,30 @@ def generate_video_presentation_task(
Celery task to generate video presentation from source content.
Updates existing video presentation record created by the tool.
"""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- result = loop.run_until_complete(
- _generate_video_presentation(
+ return run_async_celery_task(
+ lambda: _generate_video_presentation(
video_presentation_id,
source_content,
search_space_id,
user_prompt,
)
)
- loop.run_until_complete(loop.shutdown_asyncgens())
- return result
except Exception as e:
logger.error(f"Error generating video presentation: {e!s}")
- loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id))
+ # Mark FAILED via a fresh run_async_celery_task call; the loop
+ # from the failed attempt is already closed. Swallow secondary
+ # failures: the row simply stays in GENERATING until the periodic
+ # stale cleanup flushes it.
+ try:
+ run_async_celery_task(
+ lambda: _mark_video_presentation_failed(video_presentation_id)
+ )
+ except Exception:
+ logger.exception(
+ "Failed to mark video presentation %s as failed",
+ video_presentation_id,
+ )
return {"status": "failed", "video_presentation_id": video_presentation_id}
- finally:
- asyncio.set_event_loop(None)
- loop.close()
async def _mark_video_presentation_failed(video_presentation_id: int) -> None:
@@ -97,6 +115,32 @@ async def _generate_video_presentation(
video_pres.status = VideoPresentationStatus.GENERATING
await session.commit()
+ try:
+ (
+ owner_user_id,
+ billing_tier,
+ base_model,
+ ) = await _resolve_agent_billing_for_search_space(
+ session,
+ search_space_id,
+ thread_id=video_pres.thread_id,
+ )
+ except ValueError as resolve_err:
+ logger.error(
+ "VideoPresentation %s: cannot resolve billing for "
+ "search_space=%s: %s",
+ video_pres.id,
+ search_space_id,
+ resolve_err,
+ )
+ video_pres.status = VideoPresentationStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "video_presentation_id": video_pres.id,
+ "reason": "billing_resolution_failed",
+ }
+
graph_config = {
"configurable": {
"video_title": video_pres.title,
@@ -110,9 +154,52 @@ async def _generate_video_presentation(
db_session=session,
)
- graph_result = await video_presentation_graph.ainvoke(
- initial_state, config=graph_config
- )
+ try:
+ async with billable_call(
+ user_id=owner_user_id,
+ search_space_id=search_space_id,
+ billing_tier=billing_tier,
+ base_model=base_model,
+ quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
+ usage_type="video_presentation_generation",
+ call_details={
+ "video_presentation_id": video_pres.id,
+ "title": video_pres.title,
+ "thread_id": video_pres.thread_id,
+ },
+ billable_session_factory=_celery_billable_session,
+ ):
+ graph_result = await video_presentation_graph.ainvoke(
+ initial_state, config=graph_config
+ )
+ except QuotaInsufficientError as exc:
+ logger.info(
+ "VideoPresentation %s denied: out of premium credits "
+ "(used=%d/%d remaining=%d)",
+ video_pres.id,
+ exc.used_micros,
+ exc.limit_micros,
+ exc.remaining_micros,
+ )
+ video_pres.status = VideoPresentationStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "video_presentation_id": video_pres.id,
+ "reason": "premium_quota_exhausted",
+ }
+ except BillingSettlementError:
+ logger.exception(
+ "VideoPresentation %s: premium billing settlement failed",
+ video_pres.id,
+ )
+ video_pres.status = VideoPresentationStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "video_presentation_id": video_pres.id,
+ "reason": "billing_settlement_failed",
+ }
# Serialize slides (parsed content + audio info merged)
slides_raw = graph_result.get("slides", [])
@@ -143,7 +230,14 @@ async def _generate_video_presentation(
video_pres.slides = serializable_slides
video_pres.scene_codes = serializable_scene_codes
video_pres.status = VideoPresentationStatus.READY
+ logger.info(
+ "VideoPresentation %s: committing READY slides=%d scene_codes=%d",
+ video_pres.id,
+ len(serializable_slides),
+ len(serializable_scene_codes),
+ )
await session.commit()
+ logger.info("VideoPresentation %s: READY commit complete", video_pres.id)
logger.info(f"Successfully generated video presentation: {video_pres.id}")
diff --git a/surfsense_backend/app/tasks/chat/content_builder.py b/surfsense_backend/app/tasks/chat/content_builder.py
new file mode 100644
index 000000000..041cab286
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/content_builder.py
@@ -0,0 +1,515 @@
+"""Server-side mirror of the frontend's assistant-ui ``ContentPart`` projection.
+
+Background
+----------
+The streaming chat task in ``stream_new_chat`` / ``stream_resume_chat`` yields
+SSE events that the frontend folds into a ``ContentPartsState`` (see
+``surfsense_web/lib/chat/streaming-state.ts`` and the matching pipeline in
+``stream-pipeline.ts``). When a turn ends, the frontend calls
+``buildContentForPersistence(...)`` and round-trips that ``ContentPart[]``
+JSONB to ``POST /threads/{id}/messages``, which is what was historically
+written to ``new_chat_messages.content``.
+
+After the ghost-thread fix moved persistence server-side, the assistant
+row is written by ``finalize_assistant_turn`` in the streaming finally
+block. The frontend's later ``appendMessage`` is now a no-op (recovers
+via the ``(thread_id, turn_id, role)`` partial unique index added in
+migration 141), which means the *server* is now responsible for
+producing the rich ``ContentPart[]`` shape the FE expects on history
+reload — text + reasoning + tool-call cards (with ``args``, ``argsText``,
+``result``, ``langchainToolCallId``) + thinking-step buckets +
+step-separators.
+
+This module is the in-memory accumulator that mirrors the FE state for
+exactly that purpose. The streaming code calls ``on_text_*`` / ``on_reasoning_*``
+/ ``on_tool_*`` / ``on_thinking_step`` / ``on_step_separator`` /
+``mark_interrupted`` at the same call sites it yields the matching
+``streaming_service.format_*`` SSE event, so the in-memory ``parts`` list
+stays in lockstep with what the FE's pipeline would have produced live.
+``snapshot()`` is then taken once in the ``finally`` block and persisted
+in a single UPDATE.
+
+Pure synchronous state — no DB I/O, no async, no flush callbacks. The
+streaming code is responsible for driving lifecycle methods; this class
+is a thin projection helper.
+"""
+
+from __future__ import annotations
+
+import copy
+import json
+import logging
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+# Mirrors the FE's filter in ``buildContentForPersistence`` / ``buildContentForUI``:
+# only text/reasoning/tool-call parts count as "meaningful". data-thinking-steps
+# and data-step-separator decorate the meaningful parts but never stand alone
+# in a successful turn.
+_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
+
+
+class AssistantContentBuilder:
+ """Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``.
+
+ Output shape (deep copy of ``self.parts`` via ``snapshot()``) strictly
+ matches the FE ``ContentPart`` union::
+
+ | { type: "text"; text: string }
+ | { type: "reasoning"; text: string }
+ | { type: "tool-call"; toolCallId: str; toolName: str;
+ args: dict; result?: any; argsText?: str; langchainToolCallId?: str;
+ state?: "aborted" }
+ | { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } }
+ | { type: "data-step-separator"; data: { stepIndex: int } }
+
+ Order matches the wire order of the SSE events that drive the lifecycle
+ methods, with two FE-mirrored exceptions:
+
+ 1. ``data-thinking-steps`` is a *singleton* and pinned at index 0 the
+ first time we see a ``data-thinking-step`` SSE event (the FE's
+ ``updateThinkingSteps`` does ``unshift`` on first sight). Subsequent
+ thinking-step updates mutate that singleton in place.
+ 2. ``data-step-separator`` is appended only when the message already has
+ meaningful content and the previous part isn't itself a separator
+ (so the FIRST step of a turn doesn't generate a leading divider).
+ """
+
+ def __init__(self) -> None:
+ self.parts: list[dict[str, Any]] = []
+ # Index of the active text/reasoning part within ``parts`` while
+ # streaming is open; -1 means "no active part" and the next delta
+ # opens a fresh one. Mirrors ``ContentPartsState.currentTextPartIndex``.
+ self._current_text_idx: int = -1
+ self._current_reasoning_idx: int = -1
+ # ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the
+ # synthetic ``call_``-prefixed id (legacy) or the LangChain
+ # ``tool_call.id`` (parity_v2) — same key the streaming layer
+ # threads through every ``tool-input-*`` / ``tool-output-*`` event.
+ self._tool_call_idx_by_ui_id: dict[str, int] = {}
+ # Live argsText accumulator (concatenated ``tool-input-delta`` chunks)
+ # so we can reproduce the FE's ``appendToolInputDelta`` behaviour
+ # before ``tool-input-available`` overwrites it with the
+ # pretty-printed final JSON.
+ self._args_text_by_ui_id: dict[str, str] = {}
+
+ # ------------------------------------------------------------------
+ # Text
+ # ------------------------------------------------------------------
+
+ def on_text_start(self, text_id: str) -> None:
+ """Begin a fresh text block.
+
+ Symmetric to FE ``appendText``: opening text closes any active
+ reasoning so the renderer treats them as separate parts. The
+ actual text part isn't materialised here — it's lazily created
+ on the first ``on_text_delta`` so an empty start/end pair
+ leaves no trace. Matches the FE pipeline which has no explicit
+ ``text-start`` handler at all.
+ """
+ if self._current_reasoning_idx >= 0:
+ self._current_reasoning_idx = -1
+
+ def on_text_delta(self, text_id: str, delta: str) -> None:
+ if not delta:
+ return
+ if self._current_reasoning_idx >= 0:
+ # FE behaviour: a text delta after reasoning implicitly
+ # closes the reasoning block (see ``appendText`` lines
+ # 178-180).
+ self._current_reasoning_idx = -1
+ if (
+ self._current_text_idx >= 0
+ and 0 <= self._current_text_idx < len(self.parts)
+ and self.parts[self._current_text_idx].get("type") == "text"
+ ):
+ self.parts[self._current_text_idx]["text"] += delta
+ return
+ self.parts.append({"type": "text", "text": delta})
+ self._current_text_idx = len(self.parts) - 1
+
+ def on_text_end(self, text_id: str) -> None:
+ """Close the active text block.
+
+ Mirrors the wire-level ``text-end`` boundary the streaming layer
+ emits before tool calls / reasoning / step boundaries. The FE
+ pipeline implicitly closes via ``currentTextPartIndex = -1``
+ in ``addToolCall`` / ``appendReasoning`` / ``addStepSeparator``;
+ our helper does the same explicitly so callers don't have to
+ maintain that invariant per call site.
+ """
+ self._current_text_idx = -1
+
+ # ------------------------------------------------------------------
+ # Reasoning
+ # ------------------------------------------------------------------
+
+ def on_reasoning_start(self, reasoning_id: str) -> None:
+ if self._current_text_idx >= 0:
+ self._current_text_idx = -1
+
+ def on_reasoning_delta(self, reasoning_id: str, delta: str) -> None:
+ if not delta:
+ return
+ if self._current_text_idx >= 0:
+ self._current_text_idx = -1
+ if (
+ self._current_reasoning_idx >= 0
+ and 0 <= self._current_reasoning_idx < len(self.parts)
+ and self.parts[self._current_reasoning_idx].get("type") == "reasoning"
+ ):
+ self.parts[self._current_reasoning_idx]["text"] += delta
+ return
+ self.parts.append({"type": "reasoning", "text": delta})
+ self._current_reasoning_idx = len(self.parts) - 1
+
+ def on_reasoning_end(self, reasoning_id: str) -> None:
+ self._current_reasoning_idx = -1
+
+ # ------------------------------------------------------------------
+ # Tool calls
+ # ------------------------------------------------------------------
+
+ def on_tool_input_start(
+ self,
+ ui_id: str,
+ tool_name: str,
+ langchain_tool_call_id: str | None,
+ ) -> None:
+ """Register a tool-call card. Args are filled in by later events."""
+ if not ui_id:
+ return
+ # Skip duplicate registration: parity_v2 may emit
+ # ``tool-input-start`` from both ``on_chat_model_stream``
+ # (when tool_call_chunks register a name) and ``on_tool_start``
+ # (the canonical path). The FE de-dupes via ``toolCallIndices``;
+ # we mirror that here.
+ if ui_id in self._tool_call_idx_by_ui_id:
+ if langchain_tool_call_id:
+ idx = self._tool_call_idx_by_ui_id[ui_id]
+ part = self.parts[idx]
+ if not part.get("langchainToolCallId"):
+ part["langchainToolCallId"] = langchain_tool_call_id
+ return
+
+ part: dict[str, Any] = {
+ "type": "tool-call",
+ "toolCallId": ui_id,
+ "toolName": tool_name,
+ "args": {},
+ }
+ if langchain_tool_call_id:
+ part["langchainToolCallId"] = langchain_tool_call_id
+ self.parts.append(part)
+ self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
+
+ self._current_text_idx = -1
+ self._current_reasoning_idx = -1
+
+ def on_tool_input_delta(self, ui_id: str, args_chunk: str) -> None:
+ """Append a streamed args-delta chunk to the matching card's argsText.
+
+ Mirrors FE ``appendToolInputDelta``: no-ops when no card has been
+ registered yet for the given ``ui_id`` — the deltas have nowhere
+ safe to land.
+ """
+ if not ui_id or not args_chunk:
+ return
+ idx = self._tool_call_idx_by_ui_id.get(ui_id)
+ if idx is None:
+ return
+ if not (0 <= idx < len(self.parts)):
+ return
+ part = self.parts[idx]
+ if part.get("type") != "tool-call":
+ return
+ new_text = (part.get("argsText") or "") + args_chunk
+ part["argsText"] = new_text
+ self._args_text_by_ui_id[ui_id] = new_text
+
+ def on_tool_input_available(
+ self,
+ ui_id: str,
+ tool_name: str,
+ args: dict[str, Any],
+ langchain_tool_call_id: str | None,
+ ) -> None:
+ """Finalize the tool-call card's input.
+
+ Mirrors FE ``stream-pipeline.ts`` lines 127-153: replaces ``argsText``
+ with ``json.dumps(input, indent=2)`` so the post-stream card renders
+ pretty-printed JSON, sets the full ``args`` dict, and backfills
+ ``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time.
+ Also creates the card if no prior ``tool-input-start`` registered it
+ (legacy parity_v2-OFF / late-registration paths).
+ """
+ if not ui_id:
+ return
+ try:
+ final_args_text = json.dumps(args or {}, indent=2, ensure_ascii=False)
+ except (TypeError, ValueError):
+ # Defensive: ``args`` should already be JSON-safe (the
+ # streaming layer sanitizes it before emitting), but if a
+ # caller hands us a non-serializable value we still want
+ # to record the call without breaking the snapshot.
+ final_args_text = str(args)
+
+ idx = self._tool_call_idx_by_ui_id.get(ui_id)
+ if idx is not None and 0 <= idx < len(self.parts):
+ part = self.parts[idx]
+ if part.get("type") == "tool-call":
+ part["args"] = args or {}
+ part["argsText"] = final_args_text
+ if langchain_tool_call_id and not part.get("langchainToolCallId"):
+ part["langchainToolCallId"] = langchain_tool_call_id
+ return
+
+ # No prior tool-input-start: register the card now.
+ new_part: dict[str, Any] = {
+ "type": "tool-call",
+ "toolCallId": ui_id,
+ "toolName": tool_name,
+ "args": args or {},
+ "argsText": final_args_text,
+ }
+ if langchain_tool_call_id:
+ new_part["langchainToolCallId"] = langchain_tool_call_id
+ self.parts.append(new_part)
+ self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
+
+ self._current_text_idx = -1
+ self._current_reasoning_idx = -1
+
+ def on_tool_output_available(
+ self,
+ ui_id: str,
+ output: Any,
+ langchain_tool_call_id: str | None,
+ ) -> None:
+ """Attach the tool's output (``result``) to the matching card.
+
+ Mirrors FE ``updateToolCall``: backfill ``langchainToolCallId``
+ only if not already set (a NULL late-arriving value never blows
+ away an earlier known good one).
+ """
+ if not ui_id:
+ return
+ idx = self._tool_call_idx_by_ui_id.get(ui_id)
+ if idx is None or not (0 <= idx < len(self.parts)):
+ return
+ part = self.parts[idx]
+ if part.get("type") != "tool-call":
+ return
+ part["result"] = output
+ if langchain_tool_call_id and not part.get("langchainToolCallId"):
+ part["langchainToolCallId"] = langchain_tool_call_id
+
+ # ------------------------------------------------------------------
+ # Thinking steps & step separators
+ # ------------------------------------------------------------------
+
+ def on_thinking_step(
+ self,
+ step_id: str,
+ title: str,
+ status: str,
+ items: list[str] | None,
+ ) -> None:
+ """Update / insert the singleton ``data-thinking-steps`` part.
+
+ Mirrors FE ``updateThinkingSteps``: maintain a single
+ ``data-thinking-steps`` part anchored at index 0, replacing or
+ unshifting on first sight. Each ``on_thinking_step`` call
+ replaces the entry in the steps list keyed by ``step_id`` (or
+ appends if new).
+ """
+ if not step_id:
+ return
+
+ new_step = {
+ "id": step_id,
+ "title": title or "",
+ "status": status or "in_progress",
+ "items": list(items) if items else [],
+ }
+
+ # Find existing data-thinking-steps part.
+ existing_idx = -1
+ for i, p in enumerate(self.parts):
+ if p.get("type") == "data-thinking-steps":
+ existing_idx = i
+ break
+
+ if existing_idx >= 0:
+ current_steps = self.parts[existing_idx].get("data", {}).get("steps") or []
+ replaced = False
+ for i, step in enumerate(current_steps):
+ if step.get("id") == step_id:
+ current_steps[i] = new_step
+ replaced = True
+ break
+ if not replaced:
+ current_steps.append(new_step)
+ self.parts[existing_idx] = {
+ "type": "data-thinking-steps",
+ "data": {"steps": current_steps},
+ }
+ return
+
+ # First sight: unshift to position 0 (FE parity).
+ self.parts.insert(
+ 0,
+ {
+ "type": "data-thinking-steps",
+ "data": {"steps": [new_step]},
+ },
+ )
+ # Bump tracked indices since we inserted at the head.
+ if self._current_text_idx >= 0:
+ self._current_text_idx += 1
+ if self._current_reasoning_idx >= 0:
+ self._current_reasoning_idx += 1
+ for ui_id, idx in list(self._tool_call_idx_by_ui_id.items()):
+ self._tool_call_idx_by_ui_id[ui_id] = idx + 1
+
+ def on_step_separator(self) -> None:
+ """Append a ``data-step-separator`` between consecutive model steps.
+
+ Mirrors FE ``addStepSeparator``: only emit when the message
+ already has meaningful content AND the previous part isn't
+ itself a separator. ``stepIndex`` is the running count of
+ separators already in ``parts``.
+ """
+ has_content = any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts)
+ if not has_content:
+ return
+ if self.parts and self.parts[-1].get("type") == "data-step-separator":
+ return
+ step_index = sum(
+ 1 for p in self.parts if p.get("type") == "data-step-separator"
+ )
+ self.parts.append(
+ {
+ "type": "data-step-separator",
+ "data": {"stepIndex": step_index},
+ }
+ )
+ self._current_text_idx = -1
+ self._current_reasoning_idx = -1
+
+ # ------------------------------------------------------------------
+ # Interruption handling
+ # ------------------------------------------------------------------
+
+ def mark_interrupted(self) -> None:
+ """Close any open text/reasoning and flip running tools to aborted.
+
+ Called from the streaming ``finally`` block before ``snapshot()`` so
+ the persisted JSONB reflects a coherent end-state even when the
+ client disconnected mid-turn or the agent hit a fatal error.
+
+ - Active text/reasoning blocks: simply lose their "active"
+ marker (no synthetic content appended). Whatever was streamed
+ stays as-is.
+ - Tool-call parts that never received a ``result`` get
+ ``state="aborted"`` so the FE history loader can render them
+ as "interrupted" rather than "still running".
+ """
+ self._current_text_idx = -1
+ self._current_reasoning_idx = -1
+ for part in self.parts:
+ if part.get("type") != "tool-call":
+ continue
+ if "result" in part:
+ continue
+ part["state"] = "aborted"
+
+ # ------------------------------------------------------------------
+ # Snapshot & introspection
+ # ------------------------------------------------------------------
+
+ def snapshot(self) -> list[dict[str, Any]]:
+ """Return a deep copy of ``parts`` ready for SQL UPDATE / json.dumps.
+
+ Deep-copied so callers that finalize from the shielded ``finally``
+ block can't accidentally mutate the persisted payload while the
+ SQL UPDATE is in flight (the streaming layer doesn't touch the
+ builder after this call, but a defensive copy is cheap, and cheap
+ is exactly what a finally block calls for).
+ """
+ return copy.deepcopy(self.parts)
+
+ def is_empty(self) -> bool:
+ """True if no meaningful content was captured.
+
+ ``data-thinking-steps`` and ``data-step-separator`` decorate
+ meaningful content but don't count on their own — a turn that
+ only emitted a thinking step before being interrupted should
+ still be treated as empty for the status-marker fallback.
+ """
+ return not any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts)
+
+ def stats(self) -> dict[str, int]:
+ """Return counts of each part-type plus rough byte size.
+
+ Used by the streaming layer's perf logger so an ops dashboard
+ can correlate finalize latency with payload size, and so a
+ regression that quietly stops emitting tool-call parts (or
+ starts emitting hundreds) shows up in a [PERF] grep rather than
+ only as a "history reload looks weird" bug report.
+
+ ``bytes`` is the JSON-serialised payload length — what actually
+ crosses the wire to PostgreSQL's JSONB column. We compute it
+ with ``ensure_ascii=False`` to match the JSONB encoder's UTF-8
+ on-disk layout closely enough for back-of-the-envelope sizing.
+ Reasoning/text/tool-call/thinking-step/step-separator counts are
+ independent so any one can spike without the others.
+
+ Defensive: ``json.dumps`` failure (a non-serializable value
+ slipped past the streaming layer's sanitization) is reported as
+ ``bytes=-1`` rather than raised — perf logging must not be the
+ thing that breaks the streaming finally block.
+ """
+ text_blocks = 0
+ reasoning_blocks = 0
+ tool_calls = 0
+ tool_calls_completed = 0
+ tool_calls_aborted = 0
+ thinking_step_parts = 0
+ step_separators = 0
+
+ for part in self.parts:
+ kind = part.get("type")
+ if kind == "text":
+ text_blocks += 1
+ elif kind == "reasoning":
+ reasoning_blocks += 1
+ elif kind == "tool-call":
+ tool_calls += 1
+ if part.get("state") == "aborted":
+ tool_calls_aborted += 1
+ elif "result" in part:
+ tool_calls_completed += 1
+ elif kind == "data-thinking-steps":
+ thinking_step_parts += 1
+ elif kind == "data-step-separator":
+ step_separators += 1
+
+ try:
+ byte_size = len(json.dumps(self.parts, ensure_ascii=False, default=str))
+ except (TypeError, ValueError):
+ byte_size = -1
+
+ return {
+ "parts": len(self.parts),
+ "bytes": byte_size,
+ "text": text_blocks,
+ "reasoning": reasoning_blocks,
+ "tool_calls": tool_calls,
+ "tool_calls_completed": tool_calls_completed,
+ "tool_calls_aborted": tool_calls_aborted,
+ "thinking_step_parts": thinking_step_parts,
+ "step_separators": step_separators,
+ }
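A quick illustration of the lifecycle methods in action (the event ids and the ``search_docs`` tool name are made up for the example; the snapshot is what ``finalize_assistant_turn`` later persists):

builder = AssistantContentBuilder()
builder.on_text_delta("t1", "Searching")
builder.on_tool_input_start("call_1", "search_docs", "lc-123")
builder.on_tool_input_available("call_1", "search_docs", {"query": "loops"}, "lc-123")
builder.on_tool_output_available("call_1", {"hits": 3}, None)
builder.on_step_separator()
builder.on_text_delta("t2", "Found 3 documents.")

# The separator lands between the completed tool call and the new text
# block; the first step of the turn never produced a leading divider.
assert [p["type"] for p in builder.snapshot()] == [
    "text", "tool-call", "data-step-separator", "text"
]
assert not builder.is_empty()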
diff --git a/surfsense_backend/app/tasks/chat/persistence.py b/surfsense_backend/app/tasks/chat/persistence.py
new file mode 100644
index 000000000..b2b8b6a88
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/persistence.py
@@ -0,0 +1,534 @@
+"""Server-side message persistence helpers for the streaming chat agent.
+
+Historically the streaming task (``stream_new_chat``/``stream_resume_chat``)
+left ``new_chat_messages`` empty and relied on the frontend to round-trip
+``POST /threads/{id}/messages`` afterwards. That gave authenticated clients
+a "ghost-thread" abuse vector: skip the round-trip and burn LLM tokens
+without leaving an audit trail. These helpers move both writes (the user
+turn that triggered the stream and the assistant turn the stream produced)
+into the server itself, idempotent against the partial unique index
+``uq_new_chat_messages_thread_turn_role`` so legacy frontends that *do*
+keep posting via ``appendMessage`` simply hit the unique-index recovery
+path on the second writer instead of creating duplicates.
+
+Assistant turn lifecycle
+------------------------
+The assistant side is split into two helpers so we can capture the row id
+*before* the stream produces any output:
+
+* ``persist_assistant_shell`` runs immediately after ``persist_user_turn``
+ and INSERTs an empty assistant row anchored to ``(thread_id, turn_id,
+ ASSISTANT)``. Returns the row id so the streaming layer can correlate
+ later writes (token_usage, AgentActionLog future-correlation) against
+ a stable PK from the start of the turn.
+* ``finalize_assistant_turn`` runs from the streaming ``finally`` block.
+ It UPDATEs the row's ``content`` to the rich ``ContentPart[]`` snapshot
+ produced server-side by ``AssistantContentBuilder`` and writes the
+ ``token_usage`` row using ``INSERT ... ON CONFLICT DO NOTHING`` against
+ the ``uq_token_usage_message_id`` partial unique index from migration
+ 142, hard-eliminating any race against ``append_message``'s recovery
+ branch.
+
+Defensive contract
+------------------
+
+* Every helper runs inside ``shielded_async_session()`` so ``session.close()``
+ survives starlette's mid-stream cancel scope on client disconnect.
+* ``persist_user_turn`` and ``persist_assistant_shell`` use ``INSERT ... ON
+ CONFLICT DO NOTHING ... RETURNING id`` keyed on the ``(thread_id, turn_id,
+ role)`` partial unique index. On conflict the insert silently no-ops at
+ the DB level — no Python ``IntegrityError`` is constructed, which
+ eliminates spurious debugger pauses and keeps logs clean. On conflict a
+ follow-up ``SELECT`` resolves the existing row id so the streaming layer
+ can correlate writes against a stable PK.
+* ``finalize_assistant_turn`` is best-effort: it never raises. The
+ streaming ``finally`` block calls it from within
+ ``anyio.CancelScope(shield=True)`` and any raised exception there
+ would mask the real error.
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from datetime import UTC, datetime
+from typing import Any
+from uuid import UUID
+
+from sqlalchemy import text as sa_text
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+from sqlalchemy.future import select
+
+from app.db import (
+ NewChatMessage,
+ NewChatMessageRole,
+ NewChatThread,
+ TokenUsage,
+ shielded_async_session,
+)
+from app.services.token_tracking_service import (
+ TurnTokenAccumulator,
+)
+from app.utils.perf import get_perf_logger
+
+logger = logging.getLogger(__name__)
+_perf_log = get_perf_logger()
+
+
+# Empty initial assistant content. ``finalize_assistant_turn`` overwrites
+# this in a single UPDATE at end-of-stream with the full ``ContentPart[]``
+# snapshot produced by ``AssistantContentBuilder``. We persist a one-element
+# list with an empty text part so a crash between shell-INSERT and finalize
+# leaves the row in a FE-renderable shape (blank bubble) instead of
+# blowing up the history loader.
+_EMPTY_SHELL_CONTENT: list[dict[str, Any]] = [{"type": "text", "text": ""}]
+
+# Substituted content for genuinely empty turns (no text, no reasoning,
+# no tool calls). The streaming layer flips to this when
+# ``AssistantContentBuilder.is_empty()`` returns True so the persisted
+# row is at least somewhat self-describing instead of an empty text
+# bubble. The FE's ``ContentPart`` union doesn't include ``status``
+# yet, so the history loader will silently drop this part and render
+# a blank bubble (matches today's behaviour for empty turns); a follow-up
+# FE PR adds the explicit "no response" rendering.
+_STATUS_NO_RESPONSE: list[dict[str, Any]] = [
+ {"type": "status", "text": "(no text response)"}
+]
+
+
+def _build_user_content(
+ user_query: str,
+ user_image_data_urls: list[str] | None,
+ mentioned_documents: list[dict[str, Any]] | None = None,
+) -> list[dict[str, Any]]:
+ """Build the persisted user-message ``content`` (assistant-ui v2 parts).
+
+ Mirrors the shape the existing frontend posts via
+ ``appendMessage`` (see ``surfsense_web/.../new-chat/[[...chat_id]]/page.tsx``):
+
+ [{"type": "text", "text": "..."},
+ {"type": "image", "image": "data:..."},
+ {"type": "mentioned-documents", "documents": [{"id": int,
+ "title": str, "document_type": str}, ...]}]
+
+ The companion reader is
+ ``app.utils.user_message_multimodal.split_persisted_user_content_parts``
+ which expects exactly this shape — keep them in sync.
+
+ ``mentioned_documents``: optional list of ``{id, title, document_type}``
+ dicts. When non-empty (and a ``mentioned-documents`` part is not already
+ in some other input shape), a single ``{"type": "mentioned-documents",
+ "documents": [...]}`` part is appended. Mirrors the FE injection at
+ ``page.tsx:281-286`` (``persistUserTurn``).
+ """
+ parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}]
+ for url in user_image_data_urls or ():
+ if isinstance(url, str) and url:
+ parts.append({"type": "image", "image": url})
+ if mentioned_documents:
+ normalized: list[dict[str, Any]] = []
+ for doc in mentioned_documents:
+ if not isinstance(doc, dict):
+ continue
+ doc_id = doc.get("id")
+ title = doc.get("title")
+ document_type = doc.get("document_type")
+ if doc_id is None or title is None or document_type is None:
+ continue
+ normalized.append(
+ {
+ "id": doc_id,
+ "title": str(title),
+ "document_type": str(document_type),
+ }
+ )
+ if normalized:
+ parts.append({"type": "mentioned-documents", "documents": normalized})
+ return parts
+
+
+async def persist_user_turn(
+ *,
+ chat_id: int,
+ user_id: str | None,
+ turn_id: str,
+ user_query: str,
+ user_image_data_urls: list[str] | None = None,
+ mentioned_documents: list[dict[str, Any]] | None = None,
+) -> int | None:
+ """Persist the user-side row for a chat turn and return its ``id``.
+
+ Uses ``INSERT ... ON CONFLICT DO NOTHING ... RETURNING id`` keyed on the
+ ``(thread_id, turn_id, role)`` partial unique index from migration 141
+ (``WHERE turn_id IS NOT NULL``). On conflict the insert silently no-ops
+ at the DB level — no Python ``IntegrityError`` is constructed, which
+ eliminates the debugger pause that ``justMyCode=false`` + async greenlet
+ interactions used to produce, and keeps production logs clean.
+
+ Returns the ``id`` of the row that exists for this turn after the call:
+ the freshly inserted ``id`` on the happy path, or the existing ``id``
+ when a previous writer (legacy FE ``appendMessage`` racing the SSE
+ stream, redelivered request, etc.) already wrote it. Returns ``None``
+ only on genuine DB failure; the caller should yield a streaming error
+ and abort the turn so we never produce a title/assistant row that
+ isn't anchored to a persisted user message.
+
+ Other constraint violations (FK, NOT NULL, etc.) still raise
+ ``IntegrityError`` — only the ``(thread_id, turn_id, role)`` collision
+ is silenced.
+ """
+ if not turn_id:
+ # Defensive: turn_id is always populated by the streaming path
+ # before this helper is called. If it isn't, we cannot be
+ # idempotent against the unique index — refuse to write rather
+ # than create a row the unique index can't dedupe.
+ logger.error(
+ "persist_user_turn called without a turn_id (chat_id=%s); skipping",
+ chat_id,
+ )
+ return None
+
+ t0 = time.perf_counter()
+ outcome = "failed"
+ resolved_id: int | None = None
+ try:
+ async with shielded_async_session() as ws:
+ # Re-attach the thread row so we can also bump updated_at
+ # in the same write — keeps the sidebar ordering accurate
+ # when a user fires off a turn but never reaches the
+ # legacy appendMessage.
+ thread = await ws.get(NewChatThread, chat_id)
+ author_uuid: UUID | None = None
+ if user_id:
+ try:
+ author_uuid = UUID(user_id)
+ except (TypeError, ValueError):
+ logger.warning(
+ "persist_user_turn: invalid user_id=%r, persisting as anonymous",
+ user_id,
+ )
+
+ content_payload = _build_user_content(
+ user_query, user_image_data_urls, mentioned_documents
+ )
+ insert_stmt = (
+ pg_insert(NewChatMessage)
+ .values(
+ thread_id=chat_id,
+ role=NewChatMessageRole.USER,
+ content=content_payload,
+ author_id=author_uuid,
+ turn_id=turn_id,
+ )
+ .on_conflict_do_nothing(
+ index_elements=["thread_id", "turn_id", "role"],
+ index_where=sa_text("turn_id IS NOT NULL"),
+ )
+ .returning(NewChatMessage.id)
+ )
+ inserted_id = (await ws.execute(insert_stmt)).scalar()
+
+ if inserted_id is None:
+ # Conflict on partial unique index — another writer
+ # (legacy FE appendMessage, redelivered request, etc.)
+ # already persisted this row. Look it up and reuse.
+ lookup = await ws.execute(
+ select(NewChatMessage.id).where(
+ NewChatMessage.thread_id == chat_id,
+ NewChatMessage.turn_id == turn_id,
+ NewChatMessage.role == NewChatMessageRole.USER,
+ )
+ )
+ existing_id = lookup.scalars().first()
+ if existing_id is None:
+ # Conflict reported but no row found — extremely
+ # unlikely (concurrent DELETE). Surface as failure.
+ logger.warning(
+ "persist_user_turn: conflict but no matching row "
+ "(chat_id=%s, turn_id=%s)",
+ chat_id,
+ turn_id,
+ )
+ outcome = "integrity_no_match"
+ return None
+ resolved_id = int(existing_id)
+ outcome = "race_recovered"
+ else:
+ resolved_id = int(inserted_id)
+ outcome = "inserted"
+ # Bump thread.updated_at only on a real insert — when
+ # we recovered an existing row the prior writer
+ # already touched the thread.
+ if thread is not None:
+ thread.updated_at = datetime.now(UTC)
+
+ await ws.commit()
+ return resolved_id
+ except Exception:
+ logger.exception(
+ "persist_user_turn failed (chat_id=%s, turn_id=%s)",
+ chat_id,
+ turn_id,
+ )
+ return None
+ finally:
+ _perf_log.info(
+ "[persist_user_turn] outcome=%s chat_id=%s turn_id=%s "
+ "message_id=%s query_len=%d images=%d mentioned_docs=%d "
+ "in %.3fs",
+ outcome,
+ chat_id,
+ turn_id,
+ resolved_id,
+ len(user_query or ""),
+ len(user_image_data_urls or ()),
+ len(mentioned_documents or ()),
+ time.perf_counter() - t0,
+ )
+
+
+async def persist_assistant_shell(
+ *,
+ chat_id: int,
+ user_id: str | None,
+ turn_id: str,
+) -> int | None:
+ """Pre-write an empty assistant row for the turn and return its id.
+
+ Inserts a placeholder ``new_chat_messages`` row (empty text content) so
+ the streaming layer has a stable ``message_id`` to correlate against
+ for the rest of the turn. ``finalize_assistant_turn`` overwrites the
+ ``content`` field at end-of-stream with the rich ``ContentPart[]``
+ snapshot produced by ``AssistantContentBuilder``.
+
+ Returns the row id on success, ``None`` on a genuine DB failure (caller
+ should abort the turn rather than stream into a void).
+
+ Idempotent against the ``(thread_id, turn_id, ASSISTANT)`` partial unique
+ index from migration 141: if a row already exists (resume retry, racing
+ legacy frontend, redelivered request, etc.) we look it up by
+ ``(thread_id, turn_id, role)`` and return its existing id. The streaming
+ layer is then free to UPDATE that row at finalize time.
+ """
+ if not turn_id:
+ logger.error(
+ "persist_assistant_shell called without a turn_id (chat_id=%s); skipping",
+ chat_id,
+ )
+ return None
+
+ t0 = time.perf_counter()
+ outcome = "failed"
+ resolved_id: int | None = None
+ try:
+ async with shielded_async_session() as ws:
+ insert_stmt = (
+ pg_insert(NewChatMessage)
+ .values(
+ thread_id=chat_id,
+ role=NewChatMessageRole.ASSISTANT,
+ content=_EMPTY_SHELL_CONTENT,
+ author_id=None,
+ turn_id=turn_id,
+ )
+ .on_conflict_do_nothing(
+ index_elements=["thread_id", "turn_id", "role"],
+ index_where=sa_text("turn_id IS NOT NULL"),
+ )
+ .returning(NewChatMessage.id)
+ )
+ inserted_id = (await ws.execute(insert_stmt)).scalar()
+
+ if inserted_id is None:
+ # Conflict — another writer (legacy FE appendMessage,
+ # resume retry, redelivered request) wrote the
+ # (thread_id, turn_id, ASSISTANT) row first. Look it up
+ # so the streaming layer can UPDATE the same row at
+ # finalize time.
+ lookup = await ws.execute(
+ select(NewChatMessage.id).where(
+ NewChatMessage.thread_id == chat_id,
+ NewChatMessage.turn_id == turn_id,
+ NewChatMessage.role == NewChatMessageRole.ASSISTANT,
+ )
+ )
+ existing_id = lookup.scalars().first()
+ if existing_id is None:
+ logger.warning(
+ "persist_assistant_shell: conflict but no matching "
+ "(thread_id, turn_id, role) row found "
+ "(chat_id=%s, turn_id=%s)",
+ chat_id,
+ turn_id,
+ )
+ outcome = "integrity_no_match"
+ return None
+ resolved_id = int(existing_id)
+ outcome = "race_recovered"
+ else:
+ resolved_id = int(inserted_id)
+ outcome = "inserted"
+
+ await ws.commit()
+ return resolved_id
+ except Exception:
+ logger.exception(
+ "persist_assistant_shell failed (chat_id=%s, turn_id=%s)",
+ chat_id,
+ turn_id,
+ )
+ return None
+ finally:
+ _perf_log.info(
+ "[persist_assistant_shell] outcome=%s chat_id=%s turn_id=%s "
+ "message_id=%s in %.3fs",
+ outcome,
+ chat_id,
+ turn_id,
+ resolved_id,
+ time.perf_counter() - t0,
+ )
+
+
+async def finalize_assistant_turn(
+ *,
+ message_id: int,
+ chat_id: int,
+ search_space_id: int,
+ user_id: str | None,
+ turn_id: str,
+ content: list[dict[str, Any]],
+ accumulator: TurnTokenAccumulator | None,
+) -> None:
+ """Finalize the assistant row and write its token_usage.
+
+ Two writes in a single shielded session:
+
+ 1. ``UPDATE new_chat_messages SET content = :c, updated_at = now()
+ WHERE id = :id`` — overwrites the placeholder ``persist_assistant_shell``
+ wrote with the full ``ContentPart[]`` snapshot produced server-side.
+ 2. ``INSERT INTO token_usage (...) VALUES (...) ON CONFLICT (message_id)
+ WHERE message_id IS NOT NULL DO NOTHING`` — uses the partial unique
+ index ``uq_token_usage_message_id`` from migration 142 to make the
+ insert idempotent against ``append_message``'s recovery branch
+ (which uses the same ON CONFLICT clause).
+
+ Substitutes the status-marker payload when ``content`` is empty
+ (pure tool-call turn that aborted before any output, or interrupt
+ before any event arrived). The status marker is preferable to a
+ blank text bubble because token accounting still runs and an ops
+ dashboard can flag the row.
+
+ Best-effort — never raises. The streaming ``finally`` calls this
+ from within ``anyio.CancelScope(shield=True)``; any raised exception
+ here would mask the real error that triggered the cleanup.
+ """
+ if not turn_id:
+ logger.error(
+ "finalize_assistant_turn called without turn_id "
+ "(chat_id=%s, message_id=%s); skipping",
+ chat_id,
+ message_id,
+ )
+ return
+ if not message_id:
+ logger.error(
+ "finalize_assistant_turn called without message_id "
+ "(chat_id=%s, turn_id=%s); skipping",
+ chat_id,
+ turn_id,
+ )
+ return
+
+ payload: list[dict[str, Any]]
+ is_status_marker = False
+ if content:
+ payload = content
+ else:
+ payload = _STATUS_NO_RESPONSE
+ is_status_marker = True
+
+ t0 = time.perf_counter()
+ outcome = "failed"
+ token_usage_attempted = bool(
+ accumulator is not None and accumulator.calls and user_id
+ )
+ try:
+ async with shielded_async_session() as ws:
+ assistant_row = await ws.get(NewChatMessage, message_id)
+ if assistant_row is None:
+ logger.warning(
+ "finalize_assistant_turn: row not found "
+ "(chat_id=%s, message_id=%s, turn_id=%s); skipping",
+ chat_id,
+ message_id,
+ turn_id,
+ )
+ outcome = "row_missing"
+ return
+
+ assistant_row.content = payload
+ assistant_row.updated_at = datetime.now(UTC)
+
+ # Token usage. ``record_token_usage`` (used elsewhere) does
+ # SELECT-then-INSERT in two statements, which races with
+ # ``append_message``. Switch to a single INSERT ... ON
+ # CONFLICT DO NOTHING keyed on the migration-142 partial
+ # unique index so the loser silently drops its write at
+ # the DB level — exactly one row per ``message_id``,
+ # regardless of which session committed first.
+ if accumulator is not None and accumulator.calls and user_id:
+ try:
+ user_uuid = UUID(user_id)
+ except (TypeError, ValueError):
+ logger.warning(
+ "finalize_assistant_turn: invalid user_id=%r, "
+ "skipping token_usage row",
+ user_id,
+ )
+ else:
+ insert_stmt = (
+ pg_insert(TokenUsage)
+ .values(
+ usage_type="chat",
+ prompt_tokens=accumulator.total_prompt_tokens,
+ completion_tokens=accumulator.total_completion_tokens,
+ total_tokens=accumulator.grand_total,
+ cost_micros=accumulator.total_cost_micros,
+ model_breakdown=accumulator.per_message_summary(),
+ call_details={"calls": accumulator.serialized_calls()},
+ thread_id=chat_id,
+ message_id=message_id,
+ search_space_id=search_space_id,
+ user_id=user_uuid,
+ )
+ .on_conflict_do_nothing(
+ index_elements=["message_id"],
+ index_where=sa_text("message_id IS NOT NULL"),
+ )
+ )
+ await ws.execute(insert_stmt)
+
+ await ws.commit()
+ outcome = "ok"
+ except Exception:
+ logger.exception(
+ "finalize_assistant_turn failed (chat_id=%s, message_id=%s, turn_id=%s)",
+ chat_id,
+ message_id,
+ turn_id,
+ )
+ finally:
+ _perf_log.info(
+ "[finalize_assistant_turn] outcome=%s chat_id=%s message_id=%s "
+ "turn_id=%s parts=%d status_marker=%s "
+ "token_usage_attempted=%s in %.3fs",
+ outcome,
+ chat_id,
+ message_id,
+ turn_id,
+ len(payload),
+ is_status_marker,
+ token_usage_attempted,
+ time.perf_counter() - t0,
+ )
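To inspect the exact statement shape both writers rely on, the insert construct can be compiled standalone against the PostgreSQL dialect (no database needed; the printed SQL in the trailing comment is abbreviated from what SQLAlchemy actually emits):

from sqlalchemy import text as sa_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert as pg_insert

from app.db import NewChatMessage, NewChatMessageRole

stmt = (
    pg_insert(NewChatMessage)
    .values(
        thread_id=1,
        role=NewChatMessageRole.USER,
        content=[{"type": "text", "text": "hi"}],
        turn_id="turn-1",
    )
    .on_conflict_do_nothing(
        index_elements=["thread_id", "turn_id", "role"],
        index_where=sa_text("turn_id IS NOT NULL"),
    )
    .returning(NewChatMessage.id)
)
print(stmt.compile(dialect=postgresql.dialect()))
# Roughly:
#   INSERT INTO new_chat_messages (thread_id, role, content, turn_id)
#   VALUES (%(thread_id)s, %(role)s, %(content)s, %(turn_id)s)
#   ON CONFLICT (thread_id, turn_id, role) WHERE turn_id IS NOT NULL
#   DO NOTHING RETURNING new_chat_messages.id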
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index 8288fb75a..3ba3912eb 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -19,12 +19,12 @@ import re
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
-from typing import Any
+from functools import partial
+from typing import Any, Literal
from uuid import UUID
import anyio
from langchain_core.messages import HumanMessage
-from sqlalchemy import func
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
@@ -33,6 +33,8 @@ from app.agents.multi_agent_chat import (
)
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
+from app.agents.new_chat.context import SurfSenseContextSchema
+from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.feature_flags import get_flags
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.llm_config import (
@@ -46,8 +48,11 @@ from app.agents.new_chat.memory_extraction import (
extract_and_save_memory,
extract_and_save_team_memory,
)
-from app.agents.new_chat.errors import BusyError
-from app.agents.new_chat.middleware.busy_mutex import release_lock as _release_busy_lock
+from app.agents.new_chat.middleware.busy_mutex import (
+ end_turn,
+ get_cancel_state,
+ is_cancel_requested,
+)
from app.agents.new_chat.middleware.kb_persistence import (
commit_staged_filesystem_state,
)
@@ -62,6 +67,12 @@ from app.db import (
shielded_async_session,
)
from app.prompts import TITLE_GENERATION_PROMPT
+from app.services.auto_model_pin_service import (
+ is_recently_healthy,
+ mark_healthy,
+ mark_runtime_cooldown,
+ resolve_or_get_pinned_llm_config_id,
+)
from app.services.chat_session_state_service import (
clear_ai_responding,
set_ai_responding,
@@ -76,6 +87,60 @@ _background_tasks: set[asyncio.Task] = set()
_perf_log = get_perf_logger()
logger = logging.getLogger(__name__)
+TURN_CANCELLING_INITIAL_DELAY_MS = 200
+TURN_CANCELLING_BACKOFF_FACTOR = 2
+TURN_CANCELLING_MAX_DELAY_MS = 1500
+
+
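+# With the constants above, attempts 1, 2, 3, 4, 5, ... wait
+# 200, 400, 800, 1500 (capped), 1500, ... milliseconds.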
+def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
+ if attempt < 1:
+ attempt = 1
+ delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
+ TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
+ )
+ return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
+
+
+def _first_interrupt_value(state: Any) -> dict[str, Any] | None:
+ """Return the first LangGraph interrupt payload across all snapshot tasks."""
+
+ def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None:
+ if isinstance(candidate, dict):
+ value = candidate.get("value", candidate)
+ return value if isinstance(value, dict) else None
+ value = getattr(candidate, "value", None)
+ if isinstance(value, dict):
+ return value
+ if isinstance(candidate, (list, tuple)):
+ for item in candidate:
+ extracted = _extract_interrupt_value(item)
+ if extracted is not None:
+ return extracted
+ return None
+
+ for task in getattr(state, "tasks", ()) or ():
+ try:
+ interrupts = getattr(task, "interrupts", ()) or ()
+ except (AttributeError, IndexError, TypeError):
+ interrupts = ()
+ if not interrupts:
+ extracted = _extract_interrupt_value(task)
+ if extracted is not None:
+ return extracted
+ continue
+ for interrupt_item in interrupts:
+ extracted = _extract_interrupt_value(interrupt_item)
+ if extracted is not None:
+ return extracted
+ try:
+ state_interrupts = getattr(state, "interrupts", ()) or ()
+ except (AttributeError, IndexError, TypeError):
+ state_interrupts = ()
+ extracted = _extract_interrupt_value(state_interrupts)
+ if extracted is not None:
+ return extracted
+ return None
+
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
@@ -253,6 +318,19 @@ class StreamResult:
verification_succeeded: bool = False
commit_gate_passed: bool = True
commit_gate_reason: str = ""
+ # Pre-allocated assistant ``new_chat_messages.id`` for this turn,
+ # captured by ``persist_assistant_shell`` right after the user row is
+ # persisted. ``None`` for the legacy / anonymous code paths that don't
+ # opt in to server-side ``ContentPart[]`` projection.
+ assistant_message_id: int | None = None
+ # In-memory mirror of the FE's assistant-ui ``ContentPartsState``,
+ # populated by the lifecycle methods called from ``_stream_agent_events``
+ # at each ``streaming_service.format_*`` yield site. Snapshotted in the
+ # streaming ``finally`` to produce the rich JSONB persisted by
+ # ``finalize_assistant_turn``. ``repr=False`` keeps the
+ # log-on-error path (``StreamResult`` is logged in some error
+ # branches) from dumping a potentially-large parts list.
+ content_builder: Any | None = field(default=None, repr=False)
def _safe_float(value: Any, default: float = 0.0) -> float:
@@ -285,20 +363,17 @@ def _tool_output_has_error(tool_output: Any) -> bool:
return False
-def _extract_resolved_file_path(*, tool_name: str, tool_output: Any) -> str | None:
+def _extract_resolved_file_path(
+ *, tool_name: str, tool_output: Any, tool_input: Any | None = None
+) -> str | None:
if isinstance(tool_output, dict):
path_value = tool_output.get("path")
if isinstance(path_value, str) and path_value.strip():
return path_value.strip()
- text = _tool_output_to_text(tool_output)
- if tool_name == "write_file":
- match = re.search(r"Updated file\s+(.+)$", text.strip())
- if match:
- return match.group(1).strip()
- if tool_name == "edit_file":
- match = re.search(r"in '([^']+)'", text)
- if match:
- return match.group(1).strip()
+ if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict):
+ file_path = tool_input.get("file_path")
+ if isinstance(file_path, str) and file_path.strip():
+ return file_path.strip()
return None
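+# E.g. a structured output ``{"path": "docs/plan.md"}`` wins outright;
+# failing that, a write_file/edit_file call whose input carried
+# ``{"file_path": "docs/plan.md"}`` resolves to the same path (paths here
+# are illustrative).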
@@ -344,6 +419,273 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
)
+def _log_chat_stream_error(
+ *,
+ flow: Literal["new", "resume", "regenerate"],
+ error_kind: str,
+ error_code: str | None,
+ severity: Literal["info", "warn", "error"],
+ is_expected: bool,
+ request_id: str | None,
+ thread_id: int | None,
+ search_space_id: int | None,
+ user_id: str | None,
+ message: str,
+ extra: dict[str, Any] | None = None,
+) -> None:
+ payload: dict[str, Any] = {
+ "event": "chat_stream_error",
+ "flow": flow,
+ "error_kind": error_kind,
+ "error_code": error_code,
+ "severity": severity,
+ "is_expected": is_expected,
+ "request_id": request_id or "unknown",
+ "thread_id": thread_id,
+ "search_space_id": search_space_id,
+ "user_id": user_id,
+ "message": message,
+ }
+ if extra:
+ payload.update(extra)
+
+ logger = logging.getLogger(__name__)
+ rendered = json.dumps(payload, ensure_ascii=False)
+ if severity == "error":
+ logger.error("[chat_stream_error] %s", rendered)
+ elif severity == "warn":
+ logger.warning("[chat_stream_error] %s", rendered)
+ else:
+ logger.info("[chat_stream_error] %s", rendered)
+
+
+def _parse_error_payload(message: str) -> dict[str, Any] | None:
+ candidates = [message]
+ first_brace_idx = message.find("{")
+ if first_brace_idx >= 0:
+ candidates.append(message[first_brace_idx:])
+
+ for candidate in candidates:
+ try:
+ parsed = json.loads(candidate)
+ if isinstance(parsed, dict):
+ return parsed
+ except Exception:
+ continue
+ return None
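+ # E.g. 'RateLimitError: {"error": {"code": 429}}' fails the full-string
+ # parse but succeeds on the slice starting at the first brace.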
+
+
+def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
+ if not isinstance(parsed, dict):
+ return None
+ candidates: list[Any] = [parsed.get("code")]
+ nested = parsed.get("error")
+ if isinstance(nested, dict):
+ candidates.append(nested.get("code"))
+ for value in candidates:
+ try:
+ if value is None:
+ continue
+ return int(value)
+ except Exception:
+ continue
+ return None
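+ # Accepts top-level and nested codes, including string-typed ones:
+ # {"code": "429"} and {"error": {"code": 429}} both resolve to 429.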
+
+
+def _is_provider_rate_limited(exc: BaseException) -> bool:
+ """Best-effort detection for provider-side runtime throttling.
+
+ Covers LiteLLM/OpenRouter shapes like:
+ - class name contains ``RateLimit``
+ - nested payload ``{"error": {"code": 429}}``
+ - nested payload ``{"error": {"type": "rate_limit_error"}}``
+ """
+ raw = str(exc)
+ lowered = raw.lower()
+ if "ratelimit" in type(exc).__name__.lower():
+ return True
+ parsed = _parse_error_payload(raw)
+ provider_code = _extract_provider_error_code(parsed)
+ if provider_code == 429:
+ return True
+
+ provider_error_type = ""
+ if parsed:
+ top_type = parsed.get("type")
+ if isinstance(top_type, str):
+ provider_error_type = top_type.lower()
+ nested = parsed.get("error")
+ if isinstance(nested, dict):
+ nested_type = nested.get("type")
+ if isinstance(nested_type, str):
+ provider_error_type = nested_type.lower()
+ if provider_error_type == "rate_limit_error":
+ return True
+
+ return (
+ "rate limited" in lowered
+ or "rate-limited" in lowered
+ or "temporarily rate-limited upstream" in lowered
+ )
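+# E.g. Exception('{"error": {"code": 429}}') and
+# Exception("temporarily rate-limited upstream") both classify as
+# rate-limited; a plain provider timeout string does not.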
+
+
+_PREFLIGHT_TIMEOUT_SEC: float = 2.5
+_PREFLIGHT_MAX_TOKENS: int = 1
+
+
+async def _preflight_llm(llm: Any) -> None:
+ """Issue a minimal completion to confirm the pinned model isn't 429'ing.
+
+ Used before agent build / planner / classifier / title-gen so a known-bad
+ free OpenRouter deployment is detected, and the thread repinned away
+ from it, before the failure cascades
+ into multiple wasted internal calls. The probe is intentionally cheap:
+ one token, low timeout, tagged ``surfsense:internal`` so token tracking
+ and SSE pipelines treat it as overhead rather than user output.
+
+ Raises the original exception when the provider responds with a
+ rate-limit-shaped error so the caller can drive the cooldown/repin
+ branch via :func:`_is_provider_rate_limited`. Other transient failures
+ are swallowed — the caller continues to the normal stream path and the
+ in-stream recovery loop remains the safety net.
+ """
+ from litellm import acompletion
+
+ model = getattr(llm, "model", None)
+ if not model or model == "auto":
+ # Auto-mode router doesn't have a single deployment to ping; the
+ # router itself handles per-deployment rate-limit accounting.
+ return
+
+ try:
+ await acompletion(
+ model=model,
+ messages=[{"role": "user", "content": "ping"}],
+ api_key=getattr(llm, "api_key", None),
+ api_base=getattr(llm, "api_base", None),
+ max_tokens=_PREFLIGHT_MAX_TOKENS,
+ timeout=_PREFLIGHT_TIMEOUT_SEC,
+ stream=False,
+ metadata={"tags": ["surfsense:internal", "auto-pin-preflight"]},
+ )
+ except Exception as exc:
+ if _is_provider_rate_limited(exc):
+ raise
+ logging.getLogger(__name__).debug(
+ "auto_pin_preflight non_rate_limit_error model=%s err=%s",
+ model,
+ exc,
+ )
+
+
+async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
+ """Wait for a discarded speculative agent build to release shared state.
+
+ Used by the parallel preflight + agent-build path. The speculative build
+ closes over the request-scoped ``AsyncSession`` (for the brief connector
+ discovery / tool-factory window before its CPU work moves into a worker
+ thread). If preflight reports a 429 we want to fall back to the original
+ repin → reload → rebuild path, but we MUST NOT touch ``session`` again
+ until any in-flight session work owned by the speculative build has
+ fully settled — :class:`sqlalchemy.ext.asyncio.AsyncSession` is not
+ concurrency-safe and the same hazard cost us a hard ``InvalidRequestError``
+ earlier in this PR (see ``connector_service`` parallel-gather revert).
+
+ We simply ``await`` the task and swallow any exception: in this path the
+ build's outcome is irrelevant — success populates the agent cache (a free
+ side effect), failure is discarded. The wasted CPU is acceptable since
+ 429 fallbacks are rare and the original sequential code also paid the
+ full build cost on the same path.
+ """
+ with contextlib.suppress(BaseException):
+ await task
+
+
+def _classify_stream_exception(
+ exc: Exception,
+ *,
+ flow_label: str,
+) -> tuple[
+ str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
+]:
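+ # Return tuple positions: (error_kind, error_code, severity,
+ # is_expected, user_facing_message, extra_payload).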
+ raw = str(exc)
+ if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
+ busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None
+ if busy_thread_id and is_cancel_requested(busy_thread_id):
+ cancel_state = get_cancel_state(busy_thread_id)
+ attempt = cancel_state[0] if cancel_state else 1
+ retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
+ retry_after_at = int(time.time() * 1000) + retry_after_ms
+ return (
+ "thread_busy",
+ "TURN_CANCELLING",
+ "info",
+ True,
+ "A previous response is still stopping. Please try again in a moment.",
+ {
+ "retry_after_ms": retry_after_ms,
+ "retry_after_at": retry_after_at,
+ },
+ )
+ return (
+ "thread_busy",
+ "THREAD_BUSY",
+ "warn",
+ True,
+ "Another response is still finishing for this thread. Please try again in a moment.",
+ None,
+ )
+
+ if _is_provider_rate_limited(exc):
+ return (
+ "rate_limited",
+ "RATE_LIMITED",
+ "warn",
+ True,
+ "This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
+ None,
+ )
+
+ return (
+ "server_error",
+ "SERVER_ERROR",
+ "error",
+ False,
+ f"Error during {flow_label}: {raw}",
+ None,
+ )
+
+
+def _emit_stream_terminal_error(
+ *,
+ streaming_service: VercelStreamingService,
+ flow: str,
+ request_id: str | None,
+ thread_id: int,
+ search_space_id: int,
+ user_id: str | None,
+ message: str,
+ error_kind: str = "server_error",
+ error_code: str = "SERVER_ERROR",
+ severity: Literal["info", "warn", "error"] = "error",
+ is_expected: bool = False,
+ extra: dict[str, Any] | None = None,
+) -> str:
+ _log_chat_stream_error(
+ flow=flow,
+ error_kind=error_kind,
+ error_code=error_code,
+ severity=severity,
+ is_expected=is_expected,
+ request_id=request_id,
+ thread_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ message=message,
+ extra=extra,
+ )
+ return streaming_service.format_error(message, error_code=error_code, extra=extra)
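+# Call sites typically bind the per-request fields once via
+# ``functools.partial`` (see ``_emit_stream_error`` in ``stream_new_chat``
+# below) and yield the returned SSE frame directly.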
+
+
def _legacy_match_lc_id(
pending_tool_call_chunks: list[dict[str, Any]],
tool_name: str,
@@ -395,6 +737,8 @@ async def _stream_agent_events(
fallback_commit_created_by_id: str | None = None,
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
fallback_commit_thread_id: int | None = None,
+ runtime_context: Any = None,
+ content_builder: Any | None = None,
) -> AsyncGenerator[str, None]:
"""Shared async generator that streams and formats astream_events from the agent.
@@ -411,6 +755,15 @@ async def _stream_agent_events(
initial_step_id: If set, the helper inherits an already-active thinking step.
initial_step_title: Title of the inherited thinking step.
initial_step_items: Items of the inherited thinking step.
+ content_builder: Optional ``AssistantContentBuilder``. When set, every
+ ``streaming_service.format_*`` yield site also drives the matching
+ builder lifecycle method (``on_text_*``, ``on_reasoning_*``,
+ ``on_tool_*``, ``on_thinking_step``, ``on_step_separator``) so the
+ in-memory ``ContentPart[]`` projection stays in lockstep with what
+ the FE renders live. Pure in-memory accumulation — no DB I/O —
+ consumed by the streaming ``finally`` to produce the rich JSONB
+ persisted via ``finalize_assistant_turn``. ``None`` (the default)
+ is used by the anonymous / legacy code paths and is a no-op.
Yields:
SSE-formatted strings for each event.
@@ -451,6 +804,7 @@ async def _stream_agent_events(
# fallback path only and never re-pops a chunk we already streamed.
pending_tool_call_chunks: list[dict[str, Any]] = []
lc_tool_call_id_by_run: dict[str, str] = {}
+ file_path_by_run: dict[str, str] = {}
# parity_v2 only: live tool-call argument streaming. ``index_to_meta``
# is keyed by the chunk's ``index`` field — LangChain
@@ -474,12 +828,46 @@ async def _stream_agent_events(
current_lc_tool_call_id: dict[str, str | None] = {"value": None}
def _emit_tool_output(call_id: str, output: Any) -> str:
+ # Drive the builder before formatting the SSE so the in-memory
+ # ContentPart[] mirror sees the result attached to the same
+ # card the FE will render. Builder method is a no-op when
+ # ``content_builder`` is None (anonymous / legacy paths).
+ if content_builder is not None:
+ content_builder.on_tool_output_available(
+ call_id, output, current_lc_tool_call_id["value"]
+ )
return streaming_service.format_tool_output_available(
call_id,
output,
langchain_tool_call_id=current_lc_tool_call_id["value"],
)
+ def _emit_thinking_step(
+ *,
+ step_id: str,
+ title: str,
+ status: str = "in_progress",
+ items: list[str] | None = None,
+ ) -> str:
+ """Format a thinking-step SSE event and notify the builder.
+
+ Single helper used at every ``format_thinking_step`` yield site
+ in this generator. Drives ``AssistantContentBuilder.on_thinking_step``
+ first so the FE-mirror state lands the update before the SSE
+ carrying the same data leaves the wire — order matches the FE
+ pipeline (``processSharedStreamEvent`` updates state, then
+ flushes). Builder call is a no-op when ``content_builder`` is
+ None (anonymous / legacy paths).
+ """
+ if content_builder is not None:
+ content_builder.on_thinking_step(step_id, title, status, items)
+ return streaming_service.format_thinking_step(
+ step_id=step_id,
+ title=title,
+ status=status,
+ items=items,
+ )
+
def next_thinking_step_id() -> str:
nonlocal thinking_step_counter
thinking_step_counter += 1
@@ -489,7 +877,7 @@ async def _stream_agent_events(
nonlocal last_active_step_id
if last_active_step_id and last_active_step_id not in completed_step_ids:
completed_step_ids.add(last_active_step_id)
- event = streaming_service.format_thinking_step(
+ event = _emit_thinking_step(
step_id=last_active_step_id,
title=last_active_step_title,
status="completed",
@@ -499,7 +887,18 @@ async def _stream_agent_events(
return event
return None
- async for event in agent.astream_events(input_data, config=config, version="v2"):
+ # Per-invocation runtime context (Phase 1.5). When supplied,
+ # ``KnowledgePriorityMiddleware`` reads ``mentioned_document_ids``
+ # from ``runtime.context`` instead of its constructor closure — the
+ # prerequisite that lets the compiled-agent cache (Phase 1) reuse a
+ # single graph across turns. ``astream_kwargs`` gains a ``context``
+ # entry only when callers supply ``runtime_context``; leaving it
+ # ``None`` preserves the legacy code path bit-for-bit.
+ astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"}
+ if runtime_context is not None:
+ astream_kwargs["context"] = runtime_context
+
+ async for event in agent.astream_events(input_data, **astream_kwargs):
event_type = event.get("event", "")
if event_type == "on_chat_model_stream":
@@ -523,6 +922,8 @@ async def _stream_agent_events(
if parity_v2 and reasoning_delta:
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
+ if content_builder is not None:
+ content_builder.on_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is None:
completion_event = complete_current_step()
@@ -535,13 +936,21 @@ async def _stream_agent_events(
just_finished_tool = False
current_reasoning_id = streaming_service.generate_reasoning_id()
yield streaming_service.format_reasoning_start(current_reasoning_id)
+ if content_builder is not None:
+ content_builder.on_reasoning_start(current_reasoning_id)
yield streaming_service.format_reasoning_delta(
current_reasoning_id, reasoning_delta
)
+ if content_builder is not None:
+ content_builder.on_reasoning_delta(
+ current_reasoning_id, reasoning_delta
+ )
if text_delta:
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(current_reasoning_id)
+ if content_builder is not None:
+ content_builder.on_reasoning_end(current_reasoning_id)
current_reasoning_id = None
if current_text_id is None:
completion_event = complete_current_step()
@@ -554,8 +963,12 @@ async def _stream_agent_events(
just_finished_tool = False
current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(current_text_id)
+ if content_builder is not None:
+ content_builder.on_text_start(current_text_id)
yield streaming_service.format_text_delta(current_text_id, text_delta)
accumulated_text += text_delta
+ if content_builder is not None:
+ content_builder.on_text_delta(current_text_id, text_delta)
# Live tool-call argument streaming. Runs AFTER text/reasoning
# processing so chunks containing both stay in their natural
@@ -587,11 +1000,17 @@ async def _stream_agent_events(
# within the same stream window.
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
+ if content_builder is not None:
+ content_builder.on_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(
current_reasoning_id
)
+ if content_builder is not None:
+ content_builder.on_reasoning_end(
+ current_reasoning_id
+ )
current_reasoning_id = None
index_to_meta[idx] = {
@@ -604,6 +1023,8 @@ async def _stream_agent_events(
name,
langchain_tool_call_id=lc_id,
)
+ if content_builder is not None:
+ content_builder.on_tool_input_start(ui_id, name, lc_id)
# Emit args delta for any chunk at a registered
# index (including idless continuations). Once an
@@ -619,6 +1040,10 @@ async def _stream_agent_events(
yield streaming_service.format_tool_input_delta(
meta["ui_id"], args_chunk
)
+ if content_builder is not None:
+ content_builder.on_tool_input_delta(
+ meta["ui_id"], args_chunk
+ )
else:
pending_tool_call_chunks.append(tcc)
@@ -629,9 +1054,15 @@ async def _stream_agent_events(
tool_input = event.get("data", {}).get("input", {})
if tool_name in ("write_file", "edit_file"):
result.write_attempted = True
+ if isinstance(tool_input, dict):
+ file_path = tool_input.get("file_path")
+ if isinstance(file_path, str) and file_path.strip() and run_id:
+ file_path_by_run[run_id] = file_path.strip()
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
+ if content_builder is not None:
+ content_builder.on_text_end(current_text_id)
current_text_id = None
if last_active_step_title != "Synthesizing response":
@@ -652,7 +1083,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Listing files"
last_active_step_items = [ls_path]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Listing files",
status="in_progress",
@@ -667,7 +1098,7 @@ async def _stream_agent_events(
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
last_active_step_title = "Reading file"
last_active_step_items = [display_fp]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Reading file",
status="in_progress",
@@ -682,7 +1113,7 @@ async def _stream_agent_events(
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
last_active_step_title = "Writing file"
last_active_step_items = [display_fp]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Writing file",
status="in_progress",
@@ -697,7 +1128,7 @@ async def _stream_agent_events(
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
last_active_step_title = "Editing file"
last_active_step_items = [display_fp]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Editing file",
status="in_progress",
@@ -714,7 +1145,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Searching files"
last_active_step_items = [f"{pat} in {base_path}"]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Searching files",
status="in_progress",
@@ -734,7 +1165,7 @@ async def _stream_agent_events(
last_active_step_items = [
f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")
]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Searching content",
status="in_progress",
@@ -749,7 +1180,7 @@ async def _stream_agent_events(
display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:]
last_active_step_title = "Deleting file"
last_active_step_items = [display_path] if display_path else []
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Deleting file",
status="in_progress",
@@ -766,7 +1197,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Deleting folder"
last_active_step_items = [display_path] if display_path else []
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Deleting folder",
status="in_progress",
@@ -783,7 +1214,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Creating folder"
last_active_step_items = [display_path] if display_path else []
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Creating folder",
status="in_progress",
@@ -806,7 +1237,7 @@ async def _stream_agent_events(
last_active_step_items = (
[f"{display_src} → {display_dst}"] if src or dst else []
)
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Moving file",
status="in_progress",
@@ -823,7 +1254,7 @@ async def _stream_agent_events(
if todo_count
else []
)
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Planning tasks",
status="in_progress",
@@ -838,7 +1269,7 @@ async def _stream_agent_events(
display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "")
last_active_step_title = "Saving document"
last_active_step_items = [display_title]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Saving document",
status="in_progress",
@@ -854,7 +1285,7 @@ async def _stream_agent_events(
last_active_step_items = [
f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}"
]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Generating image",
status="in_progress",
@@ -870,7 +1301,7 @@ async def _stream_agent_events(
last_active_step_items = [
f"URL: {url[:80]}{'...' if len(url) > 80 else ''}"
]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Scraping webpage",
status="in_progress",
@@ -893,7 +1324,7 @@ async def _stream_agent_events(
f"Content: {content_len:,} characters",
"Preparing audio generation...",
]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Generating podcast",
status="in_progress",
@@ -914,7 +1345,7 @@ async def _stream_agent_events(
f"Topic: {report_topic}",
"Analyzing source content...",
]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title=step_title,
status="in_progress",
@@ -929,7 +1360,7 @@ async def _stream_agent_events(
display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "")
last_active_step_title = "Running command"
last_active_step_items = [f"$ {display_cmd}"]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title="Running command",
status="in_progress",
@@ -946,7 +1377,7 @@ async def _stream_agent_events(
tool_name.replace("_", " ").strip().capitalize() or tool_name
)
last_active_step_items = []
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=tool_step_id,
title=last_active_step_title,
status="in_progress",
@@ -1007,6 +1438,10 @@ async def _stream_agent_events(
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
)
+ if content_builder is not None:
+ content_builder.on_tool_input_start(
+ tool_call_id, tool_name, langchain_tool_call_id
+ )
if run_id:
ui_tool_call_id_by_run[run_id] = tool_call_id
@@ -1029,12 +1464,20 @@ async def _stream_agent_events(
_safe_input,
langchain_tool_call_id=langchain_tool_call_id,
)
+ if content_builder is not None:
+ content_builder.on_tool_input_available(
+ tool_call_id,
+ tool_name,
+ _safe_input,
+ langchain_tool_call_id,
+ )
elif event_type == "on_tool_end":
active_tool_depth = max(0, active_tool_depth - 1)
run_id = event.get("run_id", "")
tool_name = event.get("name", "unknown_tool")
raw_output = event.get("data", {}).get("output", "")
+ staged_file_path = file_path_by_run.pop(run_id, None) if run_id else None
if tool_name == "update_memory":
called_update_memory = True
@@ -1100,70 +1543,70 @@ async def _stream_agent_events(
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
if tool_name == "read_file":
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Reading file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "write_file":
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Writing file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "edit_file":
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Editing file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "glob":
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Searching files",
status="completed",
items=last_active_step_items,
)
elif tool_name == "grep":
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Searching content",
status="completed",
items=last_active_step_items,
)
elif tool_name == "rm":
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Deleting file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "rmdir":
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Deleting folder",
status="completed",
items=last_active_step_items,
)
elif tool_name == "mkdir":
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Creating folder",
status="completed",
items=last_active_step_items,
)
elif tool_name == "move_file":
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Moving file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "write_todos":
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Planning tasks",
status="completed",
@@ -1180,7 +1623,7 @@ async def _stream_agent_events(
*last_active_step_items,
result_str[:80] if is_error else "Saved to knowledge base",
]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Saving document",
status="completed",
@@ -1199,7 +1642,7 @@ async def _stream_agent_events(
else "Generation failed"
)
completed_items = [*last_active_step_items, f"Error: {error_msg}"]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Generating image",
status="completed",
@@ -1223,7 +1666,7 @@ async def _stream_agent_events(
]
else:
completed_items = [*last_active_step_items, "Content extracted"]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Scraping webpage",
status="completed",
@@ -1240,10 +1683,10 @@ async def _stream_agent_events(
if isinstance(tool_output, dict)
else "Podcast"
)
- if podcast_status == "processing":
+ if podcast_status in ("pending", "generating", "processing"):
completed_items = [
f"Title: {podcast_title}",
- "Audio generation started",
+ "Podcast generation started",
"Processing in background...",
]
elif podcast_status == "already_generating":
@@ -1252,7 +1695,7 @@ async def _stream_agent_events(
"Podcast already in progress",
"Please wait for it to complete",
]
- elif podcast_status == "error":
+ elif podcast_status in ("failed", "error"):
error_msg = (
tool_output.get("error", "Unknown error")
if isinstance(tool_output, dict)
@@ -1262,9 +1705,14 @@ async def _stream_agent_events(
f"Title: {podcast_title}",
f"Error: {error_msg[:50]}",
]
+ elif podcast_status in ("ready", "success"):
+ completed_items = [
+ f"Title: {podcast_title}",
+ "Podcast ready",
+ ]
else:
completed_items = last_active_step_items
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Generating podcast",
status="completed",
@@ -1299,7 +1747,7 @@ async def _stream_agent_events(
]
else:
completed_items = last_active_step_items
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Generating video presentation",
status="completed",
@@ -1347,7 +1795,7 @@ async def _stream_agent_events(
else:
completed_items = last_active_step_items
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title=step_title,
status="completed",
@@ -1373,7 +1821,7 @@ async def _stream_agent_events(
]
else:
completed_items = [*last_active_step_items, "Finished"]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Running command",
status="completed",
@@ -1413,7 +1861,7 @@ async def _stream_agent_events(
completed_items.append(f"(+{len(file_names) - 4} more)")
else:
completed_items = ["No files found"]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title="Listing files",
status="completed",
@@ -1425,7 +1873,7 @@ async def _stream_agent_events(
fallback_title = (
tool_name.replace("_", " ").strip().capitalize() or tool_name
)
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=original_step_id,
title=fallback_title,
status="completed",
@@ -1444,20 +1892,28 @@ async def _stream_agent_events(
if isinstance(tool_output, dict)
else {"result": tool_output},
)
- if (
- isinstance(tool_output, dict)
- and tool_output.get("status") == "success"
+ if isinstance(tool_output, dict) and tool_output.get("status") in (
+ "pending",
+ "generating",
+ "processing",
+ ):
+ yield streaming_service.format_terminal_info(
+ f"Podcast queued: {tool_output.get('title', 'Podcast')}",
+ "success",
+ )
+ elif isinstance(tool_output, dict) and tool_output.get("status") in (
+ "ready",
+ "success",
):
yield streaming_service.format_terminal_info(
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
"success",
)
- else:
- error_msg = (
- tool_output.get("error", "Unknown error")
- if isinstance(tool_output, dict)
- else "Unknown error"
- )
+ elif isinstance(tool_output, dict) and tool_output.get("status") in (
+ "failed",
+ "error",
+ ):
+ error_msg = tool_output.get("error", "Unknown error")
yield streaming_service.format_terminal_info(
f"Podcast generation failed: {error_msg}",
"error",
@@ -1548,6 +2004,9 @@ async def _stream_agent_events(
resolved_path = _extract_resolved_file_path(
tool_name=tool_name,
tool_output=tool_output,
+ tool_input={"file_path": staged_file_path}
+ if staged_file_path
+ else None,
)
result_text = _tool_output_to_text(tool_output)
if _tool_output_has_error(tool_output):
@@ -1754,7 +2213,7 @@ async def _stream_agent_events(
# Phase transitions: replace everything after topic
last_active_step_items = [*topic_items, message]
- yield streaming_service.format_thinking_step(
+ yield _emit_thinking_step(
step_id=last_active_step_id,
title=last_active_step_title,
status="in_progress",
@@ -1796,10 +2255,14 @@ async def _stream_agent_events(
elif event_type in ("on_chain_end", "on_agent_end"):
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
+ if content_builder is not None:
+ content_builder.on_text_end(current_text_id)
current_text_id = None
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
+ if content_builder is not None:
+ content_builder.on_text_end(current_text_id)
completion_event = complete_current_step()
if completion_event:
@@ -1884,8 +2347,14 @@ async def _stream_agent_events(
)
gate_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(gate_text_id)
+ if content_builder is not None:
+ content_builder.on_text_start(gate_text_id)
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
+ if content_builder is not None:
+ content_builder.on_text_delta(gate_text_id, gate_notice)
yield streaming_service.format_text_end(gate_text_id)
+ if content_builder is not None:
+ content_builder.on_text_end(gate_text_id)
yield streaming_service.format_terminal_info(gate_notice, "error")
accumulated_text = gate_notice
else:
@@ -1896,19 +2365,11 @@ async def _stream_agent_events(
result.agent_called_update_memory = called_update_memory
_log_file_contract("turn_outcome", result)
- snapshot_interrupts = getattr(state, "interrupts", ()) or ()
- interrupt_value = None
- if snapshot_interrupts:
- interrupt_value = snapshot_interrupts[0].value
- else:
- for task in state.tasks or []:
- if task.interrupts:
- interrupt_value = task.interrupts[0].value
- break
+ interrupt_value = _first_interrupt_value(state)
if interrupt_value is not None:
result.is_interrupted = True
result.interrupt_value = interrupt_value
- yield streaming_service.format_interrupt_request(interrupt_value)
+ yield streaming_service.format_interrupt_request(result.interrupt_value)
async def stream_new_chat(
@@ -1919,6 +2380,7 @@ async def stream_new_chat(
llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None,
mentioned_surfsense_doc_ids: list[int] | None = None,
+ mentioned_documents: list[dict[str, Any]] | None = None,
checkpoint_id: str | None = None,
needs_history_bootstrap: bool = False,
thread_visibility: ChatVisibility | None = None,
@@ -1927,6 +2389,7 @@ async def stream_new_chat(
filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None,
user_image_data_urls: list[str] | None = None,
+ flow: Literal["new", "regenerate"] = "new",
) -> AsyncGenerator[str, None]:
"""
Stream chat responses from the new SurfSense deep agent.
@@ -1974,14 +2437,26 @@ async def stream_new_chat(
accumulator = start_turn()
- # Premium quota tracking state
- _premium_reserved = 0
+ # Premium credit (USD micro-units) tracking state. Stores the
+ # amount reserved up front so we can release it on cancellation
+ # and finalize-debit the actual provider cost reported by LiteLLM.
+ _premium_reserved_micros = 0
_premium_request_id: str | None = None
# ``BusyError`` fires before the lock is acquired; the ``finally`` must
# not release the in-flight caller's lock.
_busy_error_raised = False
+ _emit_stream_error = partial(
+ _emit_stream_terminal_error,
+ streaming_service=streaming_service,
+ flow=flow,
+ request_id=request_id,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
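+ # Subsequent yields pass only the per-error fields, e.g.
+ # ``yield _emit_stream_error(message=..., error_kind="server_error",
+ # error_code="SERVER_ERROR")``.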
+
session = async_session_maker()
try:
# Mark AI as responding to this user for live collaboration
@@ -1989,87 +2464,280 @@ async def stream_new_chat(
await set_ai_responding(session, chat_id, UUID(user_id))
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
agent_config: AgentConfig | None = None
+ requested_llm_config_id = llm_config_id
+
+ async def _load_llm_bundle(
+ config_id: int,
+ ) -> tuple[Any, AgentConfig | None, str | None]:
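+ # Returns (llm, agent_config, error_message): (llm, cfg, None) on
+ # success, (None, None, msg) on a load failure the caller surfaces
+ # as a terminal SSE error.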
+ if config_id >= 0:
+ loaded_agent_config = await load_agent_config(
+ session=session,
+ config_id=config_id,
+ search_space_id=search_space_id,
+ )
+ if not loaded_agent_config:
+ return (
+ None,
+ None,
+ f"Failed to load NewLLMConfig with id {config_id}",
+ )
+ return (
+ create_chat_litellm_from_agent_config(loaded_agent_config),
+ loaded_agent_config,
+ None,
+ )
+
+ loaded_llm_config = load_global_llm_config_by_id(config_id)
+ if not loaded_llm_config:
+ return None, None, f"Failed to load LLM config with id {config_id}"
+ return (
+ create_chat_litellm_from_config(loaded_llm_config),
+ AgentConfig.from_yaml_config(loaded_llm_config),
+ None,
+ )
_t0 = time.perf_counter()
- if llm_config_id >= 0:
- # Positive ID: Load from NewLLMConfig database table
- agent_config = await load_agent_config(
- session=session,
- config_id=llm_config_id,
- search_space_id=search_space_id,
+ # Image-bearing turns force the Auto-pin resolver to filter the
+ # candidate pool to vision-capable cfgs (and force-repin a
+ # text-only existing pin). For explicit selections this flag is
+ # a no-op — the resolver returns the user's chosen id unchanged.
+ _requires_image_input = bool(user_image_data_urls)
+ try:
+ llm_config_id = (
+ await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ selected_llm_config_id=llm_config_id,
+ requires_image_input=_requires_image_input,
+ )
+ ).resolved_llm_config_id
+ except ValueError as pin_error:
+ # Auto-pin's "no vision-capable cfg" path raises a ValueError
+ # whose message we map to the friendly image-input SSE error
+ # so the user sees the same message regardless of whether
+ # the gate fired in Auto-mode or in the agent_config check
+ # below.
+ error_code = (
+ "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
+ if _requires_image_input and "vision-capable" in str(pin_error)
+ else "SERVER_ERROR"
)
- if not agent_config:
- yield streaming_service.format_error(
- f"Failed to load NewLLMConfig with id {llm_config_id}"
- )
- yield streaming_service.format_done()
- return
+ error_kind = (
+ "user_error"
+ if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
+ else "server_error"
+ )
+ yield _emit_stream_error(
+ message=str(pin_error),
+ error_kind=error_kind,
+ error_code=error_code,
+ )
+ yield streaming_service.format_done()
+ return
- # Create ChatLiteLLM from AgentConfig
- llm = create_chat_litellm_from_agent_config(agent_config)
- else:
- # Negative ID: Load from in-memory global configs (includes dynamic OpenRouter models)
- llm_config = load_global_llm_config_by_id(llm_config_id)
- if not llm_config:
- yield streaming_service.format_error(
- f"Failed to load LLM config with id {llm_config_id}"
- )
- yield streaming_service.format_done()
- return
-
- # Create ChatLiteLLM from global config dict
- llm = create_chat_litellm_from_config(llm_config)
- agent_config = AgentConfig.from_yaml_config(llm_config)
+ llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
+ if llm_load_error:
+ yield _emit_stream_error(
+ message=llm_load_error,
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
+ yield streaming_service.format_done()
+ return
_perf_log.info(
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
time.perf_counter() - _t0,
llm_config_id,
)
- # Premium quota reservation — applies to explicitly premium configs
- # AND Auto mode (which may route to premium models).
+ # Capability safety net: a turn carrying user-uploaded images
+ # cannot be routed to a chat config that LiteLLM's authoritative
+ # model map *explicitly* marks as text-only (``supports_vision``
+ # set to False). The check is intentionally narrow — it only
+ # fires when LiteLLM is *certain* the model can't accept image
+ # input. Unknown / unmapped / vision-capable models pass
+ # through. Without this guard a known-text-only model would 404
+ # at the provider with ``"No endpoints found that support image
+ # input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk;
+ # failing here lets us return a friendly message that tells the
+ # user what to change.
+ if user_image_data_urls and agent_config is not None:
+ from app.services.provider_capabilities import (
+ is_known_text_only_chat_model,
+ )
+
+ agent_litellm_params = agent_config.litellm_params or {}
+ agent_base_model = (
+ agent_litellm_params.get("base_model")
+ if isinstance(agent_litellm_params, dict)
+ else None
+ )
+ if is_known_text_only_chat_model(
+ provider=agent_config.provider,
+ model_name=agent_config.model_name,
+ base_model=agent_base_model,
+ custom_provider=agent_config.custom_provider,
+ ):
+ model_label = (
+ agent_config.config_name or agent_config.model_name or "model"
+ )
+ yield _emit_stream_error(
+ message=(
+ f"The selected model ({model_label}) does not support "
+ "image input. Switch to a vision-capable model "
+ "(e.g. GPT-4o, Claude, Gemini) or remove the image "
+ "attachment and try again."
+ ),
+ error_kind="user_error",
+ error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
+ )
+ yield streaming_service.format_done()
+ return
+
+ # Premium quota reservation for pinned premium model only.
_needs_premium_quota = (
- agent_config is not None
- and user_id
- and (agent_config.is_premium or agent_config.is_auto_mode)
+ agent_config is not None and user_id and agent_config.is_premium
)
if _needs_premium_quota:
import uuid as _uuid
- from app.config import config as _app_config
- from app.services.token_quota_service import TokenQuotaService
+ from app.services.token_quota_service import (
+ TokenQuotaService,
+ estimate_call_reserve_micros,
+ )
_premium_request_id = _uuid.uuid4().hex[:16]
- reserve_amount = min(
- agent_config.quota_reserve_tokens
- or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
- _app_config.QUOTA_MAX_RESERVE_PER_CALL,
+ _agent_litellm_params = agent_config.litellm_params or {}
+ _agent_base_model = (
+ _agent_litellm_params.get("base_model") or agent_config.model_name or ""
+ )
+ reserve_amount_micros = estimate_call_reserve_micros(
+ base_model=_agent_base_model,
+ quota_reserve_tokens=agent_config.quota_reserve_tokens,
)
async with shielded_async_session() as quota_session:
quota_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_premium_request_id,
- reserve_tokens=reserve_amount,
+ reserve_micros=reserve_amount_micros,
)
- _premium_reserved = reserve_amount
+ _premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed:
- if agent_config.is_premium:
- yield streaming_service.format_error(
- "Premium token quota exceeded. Please purchase more tokens to continue using premium models."
+ if requested_llm_config_id == 0:
+ try:
+ llm_config_id = (
+ await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ selected_llm_config_id=0,
+ force_repin_free=True,
+ requires_image_input=_requires_image_input,
+ )
+ ).resolved_llm_config_id
+ except ValueError as pin_error:
+ yield _emit_stream_error(
+ message=str(pin_error),
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
+ yield streaming_service.format_done()
+ return
+
+ llm, agent_config, llm_load_error = await _load_llm_bundle(
+ llm_config_id
+ )
+ if llm_load_error:
+ yield _emit_stream_error(
+ message=llm_load_error,
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
+ yield streaming_service.format_done()
+ return
+ _premium_request_id = None
+ _premium_reserved_micros = 0
+ _log_chat_stream_error(
+ flow=flow,
+ error_kind="premium_quota_exhausted",
+ error_code="PREMIUM_QUOTA_EXHAUSTED",
+ severity="info",
+ is_expected=True,
+ request_id=request_id,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ message=(
+ "Premium quota exhausted on pinned model; auto-fallback switched to a free model"
+ ),
+ extra={
+ "fallback_config_id": llm_config_id,
+ "auto_fallback": True,
+ },
+ )
+ else:
+ yield _emit_stream_error(
+ message=(
+ "Buy more tokens to continue with this model, or switch to a free model"
+ ),
+ error_kind="premium_quota_exhausted",
+ error_code="PREMIUM_QUOTA_EXHAUSTED",
+ severity="info",
+ is_expected=True,
+ extra={
+ "resolved_config_id": llm_config_id,
+ "auto_fallback": False,
+ },
)
yield streaming_service.format_done()
return
- # Auto mode: quota exhausted but we can still proceed
- # (the router may pick a free model). Reset reservation.
- _premium_request_id = None
- _premium_reserved = 0
if not llm:
- yield streaming_service.format_error("Failed to create LLM instance")
+ yield _emit_stream_error(
+ message="Failed to create LLM instance",
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
yield streaming_service.format_done()
return
+ # Auto-mode preflight ping. Runs ONLY for thread-pinned auto cfgs
+ # (negative ids selected via ``resolve_or_get_pinned_llm_config_id``)
+ # whose health hasn't already been confirmed within the TTL window.
+ # Detecting a 429 here lets us repin BEFORE the planner/classifier/
+ # title-generation LLM calls fan out and each independently hit the
+ # same upstream rate limit.
+ #
+ # PERF: preflight is a network round-trip to the LLM provider (~1-5s)
+ # and is independent of the agent build (CPU-bound, ~5-7s). They used
+ # to run sequentially → ``preflight + build`` on cold cache ≈ 11.5s.
+ # We now kick off preflight as a background task FIRST, then run the
+ # synchronous setup work and the agent build in parallel. In the
+ # success path (the common case) total wall time drops to roughly
+ # ``max(preflight, build)`` — the preflight finishes during the
+ # agent compile and we just consume its result. In the rare 429
+ # path the speculative build is awaited to completion (so its
+ # session usage is fully released) via
+ # :func:`_settle_speculative_agent_build`, then discarded, and
+ # we fall back to the original repin-and-rebuild flow.
+ preflight_needed = (
+ requested_llm_config_id == 0
+ and llm_config_id < 0
+ and not is_recently_healthy(llm_config_id)
+ )
+ preflight_task: asyncio.Task[None] | None = None
+ _t_preflight = 0.0
+ if preflight_needed:
+ _t_preflight = time.perf_counter()
+ preflight_task = asyncio.create_task(
+ _preflight_llm(llm),
+ name=f"auto_pin_preflight:{llm_config_id}",
+ )
+
# Create connector service
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
@@ -2098,24 +2766,17 @@ async def stream_new_chat(
use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED)
_t0 = time.perf_counter()
- if use_multi_agent:
- agent = await create_registry_deep_agent(
- llm=llm,
- search_space_id=search_space_id,
- db_session=session,
- connector_service=connector_service,
- checkpointer=checkpointer,
- user_id=user_id,
- thread_id=chat_id,
- agent_config=agent_config,
- firecrawl_api_key=firecrawl_api_key,
- thread_visibility=visibility,
- filesystem_selection=filesystem_selection,
- mentioned_document_ids=mentioned_document_ids,
- disabled_tools=disabled_tools,
- )
- else:
- agent = await create_surfsense_deep_agent(
+ agent_factory = (
+ create_registry_deep_agent
+ if use_multi_agent
+ else create_surfsense_deep_agent
+ )
+ # Speculative agent build — runs in parallel with the preflight
+ # task (if any). Built with the *current* ``llm`` / ``agent_config``;
+ # if preflight reports 429 we will discard this future and rebuild
+ # against the freshly pinned config below.
+ agent_build_task = asyncio.create_task(
+ agent_factory(
llm=llm,
search_space_id=search_space_id,
db_session=session,
@@ -2129,7 +2790,116 @@ async def stream_new_chat(
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
- )
+ ),
+ name="agent_build:stream_new_chat",
+ )
+
+ agent: Any = None
+ if preflight_task is not None:
+ try:
+ await preflight_task
+ mark_healthy(llm_config_id)
+ _perf_log.info(
+ "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
+ llm_config_id,
+ time.perf_counter() - _t_preflight,
+ )
+ except Exception as preflight_exc:
+ # Both branches below need the session: the non-429 path
+ # may unwind via cleanup that uses ``session``, and the
+ # 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id``
+ # against it. Wait for the speculative build to release its
+ # session usage before we proceed.
+ await _settle_speculative_agent_build(agent_build_task)
+ if not _is_provider_rate_limited(preflight_exc):
+ raise
+ # 429: speculative agent is discarded; run the original
+ # repin → reload → rebuild path against the freshly
+ # pinned config.
+ previous_config_id = llm_config_id
+ mark_runtime_cooldown(
+ previous_config_id, reason="preflight_rate_limited"
+ )
+ try:
+ llm_config_id = (
+ await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ selected_llm_config_id=0,
+ exclude_config_ids={previous_config_id},
+ requires_image_input=_requires_image_input,
+ )
+ ).resolved_llm_config_id
+ except ValueError as pin_error:
+ yield _emit_stream_error(
+ message=str(pin_error),
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
+ yield streaming_service.format_done()
+ return
+
+ llm, agent_config, llm_load_error = await _load_llm_bundle(
+ llm_config_id
+ )
+ if llm_load_error or not llm:
+ yield _emit_stream_error(
+ message=llm_load_error or "Failed to create LLM instance",
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
+ yield streaming_service.format_done()
+ return
+ # Trust the freshly-resolved cfg for the remainder of this
+ # turn rather than recursing into another preflight; the
+ # in-stream 429 recovery loop is still in place as the
+ # safety net if even this fallback hits an upstream cap.
+ mark_healthy(llm_config_id)
+ _log_chat_stream_error(
+ flow=flow,
+ error_kind="rate_limited",
+ error_code="RATE_LIMITED",
+ severity="info",
+ is_expected=True,
+ request_id=request_id,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ message=(
+ "Auto-pinned model failed preflight; switched to another "
+ "eligible model and continuing."
+ ),
+ extra={
+ "auto_runtime_recover": True,
+ "preflight": True,
+ "previous_config_id": previous_config_id,
+ "fallback_config_id": llm_config_id,
+ },
+ )
+ # Rebuild against the new llm/agent_config. Sequential
+ # here because we no longer have anything to overlap with.
+ agent = await agent_factory(
+ llm=llm,
+ search_space_id=search_space_id,
+ db_session=session,
+ connector_service=connector_service,
+ checkpointer=checkpointer,
+ user_id=user_id,
+ thread_id=chat_id,
+ agent_config=agent_config,
+ firecrawl_api_key=firecrawl_api_key,
+ thread_visibility=visibility,
+ disabled_tools=disabled_tools,
+ mentioned_document_ids=mentioned_document_ids,
+ filesystem_selection=filesystem_selection,
+ )
+
+ if agent is None:
+ # Either no preflight was needed, or preflight succeeded —
+ # in both cases the speculative build is the agent we want.
+ agent = await agent_build_task
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
)
@@ -2301,6 +3071,97 @@ async def stream_new_chat(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)
+ yield streaming_service.format_data("turn-status", {"status": "busy"})
+
+ # Persist the user-side row for this turn before any expensive
+ # work runs. Closes the "ghost-thread" abuse vector
+ # (authenticated client hits POST /new_chat then never calls
+ # /messages — empty new_chat_messages, free LLM completion).
+ # Idempotent against the unique index in migration 141 so the
+ # legacy frontend appendMessage call is a no-op on the second
+ # writer. Hard failure aborts the turn so we never produce a
+ # title or assistant row that isn't anchored to a persisted
+ # user message.
+ from app.tasks.chat.content_builder import AssistantContentBuilder
+ from app.tasks.chat.persistence import (
+ persist_assistant_shell,
+ persist_user_turn,
+ )
+
+ user_message_id = await persist_user_turn(
+ chat_id=chat_id,
+ user_id=user_id,
+ turn_id=stream_result.turn_id,
+ user_query=user_query,
+ user_image_data_urls=user_image_data_urls,
+ mentioned_documents=mentioned_documents,
+ )
+ if user_message_id is None:
+ yield _emit_stream_error(
+ message=(
+ "We couldn't save your message. Please try again in a moment."
+ ),
+ error_kind="server_error",
+ error_code="MESSAGE_PERSIST_FAILED",
+ )
+ yield streaming_service.format_data("turn-status", {"status": "idle"})
+ yield streaming_service.format_finish_step()
+ yield streaming_service.format_finish()
+ yield streaming_service.format_done()
+ return
+
+ # Emit canonical user message id BEFORE any LLM streaming so the
+ # FE can rename its optimistic ``msg-user-XXX`` placeholder to
+ # ``msg-{user_message_id}`` and unlock features gated on a real
+ # DB id (comments, edit-from-this-message). See B4 in
+ # ``sse-based_message_id_handshake`` plan.
+ yield streaming_service.format_data(
+ "user-message-id",
+ {"message_id": user_message_id, "turn_id": stream_result.turn_id},
+ )
+
+ # Pre-write the assistant row for this turn so we have a stable
+ # ``message_id`` to anchor mid-stream metadata (token_usage,
+ # future agent_action_log.message_id correlation) and a
+ # write-once UPDATE target at finalize time. Idempotent against
+ # the (thread_id, turn_id, ASSISTANT) partial unique index from
+ # migration 141 — if the legacy frontend appendMessage races
+ # this, we recover the existing row's id.
+ assistant_message_id = await persist_assistant_shell(
+ chat_id=chat_id,
+ user_id=user_id,
+ turn_id=stream_result.turn_id,
+ )
+ if assistant_message_id is None:
+ # Genuine DB failure — abort the turn rather than stream
+ # into a void. The user row is already persisted so the
+ # legacy "ghost-thread" gate isn't reopened.
+ yield _emit_stream_error(
+ message=(
+ "We couldn't initialize the assistant message. Please try again."
+ ),
+ error_kind="server_error",
+ error_code="MESSAGE_PERSIST_FAILED",
+ )
+ yield streaming_service.format_data("turn-status", {"status": "idle"})
+ yield streaming_service.format_finish_step()
+ yield streaming_service.format_finish()
+ yield streaming_service.format_done()
+ return
+
+ # Emit canonical assistant message id BEFORE any LLM streaming
+ # so the FE can rename its optimistic ``msg-assistant-XXX``
+ # placeholder to ``msg-{assistant_message_id}`` and bind
+ # ``tokenUsageStore`` / ``pendingInterrupt`` to the real id
+ # immediately. See B4 in ``sse-based_message_id_handshake``
+ # plan.
+ yield streaming_service.format_data(
+ "assistant-message-id",
+ {"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
+ )
+
+ stream_result.assistant_message_id = assistant_message_id
+ stream_result.content_builder = AssistantContentBuilder()
# Initial thinking step - analyzing the request
if mentioned_surfsense_docs:
@@ -2334,6 +3195,15 @@ async def stream_new_chat(
initial_items = [f"{action_verb}: {' '.join(processing_parts)}"]
initial_step_id = "thinking-1"
+ # Drive the builder for this initial thinking step too — the
+ # ``_emit_thinking_step`` helper lives inside ``_stream_agent_events``
+ # so it isn't in scope here, but the FE folds this step into
+ # the same singleton ``data-thinking-steps`` part as everything
+ # the agent stream emits later. Mirror that fold server-side.
+ if stream_result.content_builder is not None:
+ stream_result.content_builder.on_thinking_step(
+ initial_step_id, initial_title, "in_progress", initial_items
+ )
yield streaming_service.format_thinking_step(
step_id=initial_step_id,
title=initial_title,
@@ -2350,16 +3220,34 @@ async def stream_new_chat(
# Check if this is the first assistant response so we can generate
# a title in parallel with the agent stream (better UX than waiting
# until after the full response).
- assistant_count_result = await session.execute(
- select(func.count(NewChatMessage.id)).filter(
+ # Use a LIMIT 1 EXISTS-style probe rather than COUNT(*) because
+ # this is now a hot path executed on every turn, and COUNT scales
+ # with thread length (server-side persistence can grow rows
+ # quickly under power users).
+ #
+ # IMPORTANT: ``persist_assistant_shell`` above (line ~3112) already
+ # inserted THIS turn's assistant row. We must therefore exclude
+ # it from the probe; otherwise the probe finds this turn's own shell
+ # row on every turn, including the very first, so ``is_first_response``
+ # is never True and title generation never runs for new threads.
+ # Excluding by primary key (``id != assistant_message_id``)
+ # is bulletproof regardless of ``turn_id`` shape (legacy NULLs,
+ # resume turns, etc.).
+ first_assistant_probe = await session.execute(
+ select(NewChatMessage.id)
+ .filter(
NewChatMessage.thread_id == chat_id,
NewChatMessage.role == "assistant",
+ NewChatMessage.id != assistant_message_id,
)
+ .limit(1)
)
- is_first_response = (assistant_count_result.scalar() or 0) == 0
+ is_first_response = first_assistant_probe.scalars().first() is None
title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
- if is_first_response:
+ # Gate title generation on a persisted user message so a stream
+ # that fails before persistence (we abort above) can never leave
+ # behind a thread with a generated title and no anchoring rows.
+ if is_first_response and user_message_id is not None:
async def _generate_title() -> tuple[str | None, dict | None]:
"""Generate a short title via litellm.acompletion.
@@ -2375,6 +3263,7 @@ async def stream_new_chat(
from litellm import acompletion
from app.services.llm_router_service import LLMRouterService
+ from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import _turn_accumulator
_turn_accumulator.set(None)
@@ -2395,11 +3284,32 @@ async def stream_new_chat(
model="auto", messages=messages
)
else:
+ # Apply the same ``api_base`` cascade chat / vision /
+ # image-gen call sites use so we never inherit
+ # ``litellm.api_base`` (commonly set by
+ # ``AZURE_OPENAI_ENDPOINT``) when the chat config
+ # itself ships an empty ``api_base``. Without this
+ # the title-gen on an OpenRouter chat config would
+ # 404 against the inherited Azure endpoint — see
+ # ``provider_api_base`` docstring for the same
+ # bug repro on the image-gen / vision paths.
+ raw_model = getattr(llm, "model", "") or ""
+ provider_prefix = (
+ raw_model.split("/", 1)[0] if "/" in raw_model else None
+ )
+ provider_value = (
+ agent_config.provider if agent_config is not None else None
+ )
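+ # Assumed precedence inside ``resolve_api_base`` (per its
+ # docstring): a non-empty config ``api_base`` wins, else a
+ # provider default derived from ``provider`` /
+ # ``provider_prefix``, else None, so the process-wide
+ # ``litellm.api_base`` is never inherited implicitly.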
+ title_api_base = resolve_api_base(
+ provider=provider_value,
+ provider_prefix=provider_prefix,
+ config_api_base=getattr(llm, "api_base", None),
+ )
response = await acompletion(
- model=llm.model,
+ model=raw_model,
messages=messages,
api_key=getattr(llm, "api_key", None),
- api_base=getattr(llm, "api_base", None),
+ api_base=title_api_base,
)
usage_info = None
@@ -2433,56 +3343,172 @@ async def stream_new_chat(
title_emitted = False
+ # Build the per-invocation runtime context (Phase 1.5).
+ # ``mentioned_document_ids`` is read by ``KnowledgePriorityMiddleware``
+ # via ``runtime.context.mentioned_document_ids`` instead of its
+ # ``__init__`` closure — that way the same compiled-agent instance
+ # can serve multiple turns with different mention lists.
+ runtime_context = SurfSenseContextSchema(
+ search_space_id=search_space_id,
+ mentioned_document_ids=list(mentioned_document_ids or []),
+ request_id=request_id,
+ turn_id=stream_result.turn_id,
+ )
+
_t_stream_start = time.perf_counter()
_first_event_logged = False
- async for sse in _stream_agent_events(
- agent=agent,
- config=config,
- input_data=input_state,
- streaming_service=streaming_service,
- result=stream_result,
- step_prefix="thinking",
- initial_step_id=initial_step_id,
- initial_step_title=initial_title,
- initial_step_items=initial_items,
- fallback_commit_search_space_id=search_space_id,
- fallback_commit_created_by_id=user_id,
- fallback_commit_filesystem_mode=(
- filesystem_selection.mode
- if filesystem_selection
- else FilesystemMode.CLOUD
- ),
- fallback_commit_thread_id=chat_id,
- ):
- if not _first_event_logged:
- _perf_log.info(
- "[stream_new_chat] First agent event in %.3fs (time since stream start), "
- "%.3fs (total since request start) (chat_id=%s)",
- time.perf_counter() - _t_stream_start,
- time.perf_counter() - _t_total,
- chat_id,
- )
- _first_event_logged = True
- yield sse
-
- # Inject title update mid-stream as soon as the background task finishes
- if title_task is not None and title_task.done() and not title_emitted:
- generated_title, title_usage = title_task.result()
- if title_usage:
- accumulator.add(**title_usage)
- if generated_title:
- async with shielded_async_session() as title_session:
- title_thread_result = await title_session.execute(
- select(NewChatThread).filter(NewChatThread.id == chat_id)
+ runtime_rate_limit_recovered = False
+ while True:
+ try:
+ async for sse in _stream_agent_events(
+ agent=agent,
+ config=config,
+ input_data=input_state,
+ streaming_service=streaming_service,
+ result=stream_result,
+ step_prefix="thinking",
+ initial_step_id=initial_step_id,
+ initial_step_title=initial_title,
+ initial_step_items=initial_items,
+ fallback_commit_search_space_id=search_space_id,
+ fallback_commit_created_by_id=user_id,
+ fallback_commit_filesystem_mode=(
+ filesystem_selection.mode
+ if filesystem_selection
+ else FilesystemMode.CLOUD
+ ),
+ fallback_commit_thread_id=chat_id,
+ runtime_context=runtime_context,
+ content_builder=stream_result.content_builder,
+ ):
+ if not _first_event_logged:
+ _perf_log.info(
+ "[stream_new_chat] First agent event in %.3fs (time since stream start), "
+ "%.3fs (total since request start) (chat_id=%s)",
+ time.perf_counter() - _t_stream_start,
+ time.perf_counter() - _t_total,
+ chat_id,
)
- title_thread = title_thread_result.scalars().first()
- if title_thread:
- title_thread.title = generated_title
- await title_session.commit()
- yield streaming_service.format_thread_title_update(
- chat_id, generated_title
+ _first_event_logged = True
+ yield sse
+
+ # Inject title update mid-stream as soon as the background
+ # task finishes.
+ if (
+ title_task is not None
+ and title_task.done()
+ and not title_emitted
+ ):
+ generated_title, title_usage = title_task.result()
+ if title_usage:
+ accumulator.add(**title_usage)
+ if generated_title:
+ async with shielded_async_session() as title_session:
+ title_thread_result = await title_session.execute(
+ select(NewChatThread).filter(
+ NewChatThread.id == chat_id
+ )
+ )
+ title_thread = title_thread_result.scalars().first()
+ if title_thread:
+ title_thread.title = generated_title
+ await title_session.commit()
+ yield streaming_service.format_thread_title_update(
+ chat_id, generated_title
+ )
+ title_emitted = True
+ break
+ except Exception as stream_exc:
+ can_runtime_recover = (
+ not runtime_rate_limit_recovered
+ and requested_llm_config_id == 0
+ and llm_config_id < 0
+ and not _first_event_logged
+ and _is_provider_rate_limited(stream_exc)
+ )
+ if not can_runtime_recover:
+ raise
+
+ runtime_rate_limit_recovered = True
+ previous_config_id = llm_config_id
+ # The failed attempt may still hold the per-thread busy mutex
+ # (middleware teardown can lag behind raised provider errors).
+ # Force release before we retry within the same request.
+ end_turn(str(chat_id))
+ mark_runtime_cooldown(
+ previous_config_id,
+ reason="provider_rate_limited",
+ )
+
+ llm_config_id = (
+ await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ selected_llm_config_id=0,
+ exclude_config_ids={previous_config_id},
+ requires_image_input=_requires_image_input,
)
- title_emitted = True
+ ).resolved_llm_config_id
+
+ llm, agent_config, llm_load_error = await _load_llm_bundle(
+ llm_config_id
+ )
+ if llm_load_error:
+ raise stream_exc
+
+ # Title generation uses the initial llm object. After a runtime
+ # repin we keep the stream focused on response recovery and skip
+ # title generation for this turn.
+ if title_task is not None and not title_task.done():
+ title_task.cancel()
+ title_task = None
+
+ _t0 = time.perf_counter()
+ agent = await create_surfsense_deep_agent(
+ llm=llm,
+ search_space_id=search_space_id,
+ db_session=session,
+ connector_service=connector_service,
+ checkpointer=checkpointer,
+ user_id=user_id,
+ thread_id=chat_id,
+ agent_config=agent_config,
+ firecrawl_api_key=firecrawl_api_key,
+ thread_visibility=visibility,
+ disabled_tools=disabled_tools,
+ mentioned_document_ids=mentioned_document_ids,
+ filesystem_selection=filesystem_selection,
+ )
+ _perf_log.info(
+ "[stream_new_chat] Runtime rate-limit recovery repinned "
+ "config_id=%s -> %s and rebuilt agent in %.3fs",
+ previous_config_id,
+ llm_config_id,
+ time.perf_counter() - _t0,
+ )
+ _log_chat_stream_error(
+ flow=flow,
+ error_kind="rate_limited",
+ error_code="RATE_LIMITED",
+ severity="info",
+ is_expected=True,
+ request_id=request_id,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ message=(
+ "Auto-pinned model hit runtime rate limit; switched to "
+ "another eligible model and retried."
+ ),
+ extra={
+ "auto_runtime_recover": True,
+ "previous_config_id": previous_config_id,
+ "fallback_config_id": llm_config_id,
+ },
+ )
+ continue
_perf_log.info(
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
@@ -2497,9 +3523,10 @@ async def stream_new_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
+ "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -2510,6 +3537,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@@ -2537,29 +3565,25 @@ async def stream_new_chat(
chat_id, generated_title
)
- # Finalize premium quota with actual tokens.
- # For Auto mode, only count tokens from calls that used premium models.
+ # Finalize premium credit debit with the actual provider cost
+ # reported by LiteLLM, summed across every call in the turn.
+ # Mirrors the pre-cost behaviour of "premium turn → all calls
+ # count" so free sub-agent calls during a premium turn still
+ # contribute to the bill (they're $0 in practice anyway).
if _premium_request_id and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
- if agent_config and agent_config.is_auto_mode:
- from app.services.llm_router_service import LLMRouterService
-
- actual_premium_tokens = LLMRouterService.compute_premium_tokens(
- accumulator.calls
- )
- else:
- actual_premium_tokens = accumulator.grand_total
-
async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_finalize(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_premium_request_id,
- actual_tokens=actual_premium_tokens,
- reserved_tokens=_premium_reserved,
+ actual_micros=accumulator.total_cost_micros,
+ reserved_micros=_premium_reserved_micros,
)
+ _premium_request_id = None
+ _premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s",
@@ -2569,9 +3593,10 @@ async def stream_new_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] normal new_chat: calls=%d total=%d summary=%s",
+ "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -2582,6 +3607,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@@ -2617,13 +3643,7 @@ async def stream_new_chat(
task.add_done_callback(_background_tasks.discard)
# Finish the step and message
- yield streaming_service.format_finish_step()
- yield streaming_service.format_finish()
- yield streaming_service.format_done()
-
- except BusyError as e:
- _busy_error_raised = True
- yield streaming_service.format_error(str(e))
+ yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
@@ -2632,12 +3652,41 @@ async def stream_new_chat(
# Handle any errors
import traceback
+ # ``BusyError`` fires before the agent acquires the lock; the
+ # cleanup path must skip lock release to avoid freeing the
+ # in-flight caller's lock. Classification is handled below.
+ if isinstance(e, BusyError):
+ _busy_error_raised = True
+
+ (
+ error_kind,
+ error_code,
+ severity,
+ is_expected,
+ user_message,
+ error_extra,
+ ) = _classify_stream_exception(e, flow_label="chat")
error_message = f"Error during chat: {e!s}"
print(f"[stream_new_chat] {error_message}")
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
+ if error_code == "TURN_CANCELLING":
+ status_payload: dict[str, Any] = {"status": "cancelling"}
+ if error_extra:
+ status_payload.update(error_extra)
+ yield streaming_service.format_data("turn-status", status_payload)
+ else:
+ yield streaming_service.format_data("turn-status", {"status": "busy"})
- yield streaming_service.format_error(error_message)
+ yield _emit_stream_error(
+ message=user_message,
+ error_kind=error_kind,
+ error_code=error_code,
+ severity=severity,
+ is_expected=is_expected,
+ extra=error_extra,
+ )
+ yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
@@ -2653,8 +3702,12 @@ async def stream_new_chat(
# (CancelledError is a BaseException), and the rest of the
# finally block — including session.close() — would never run.
with anyio.CancelScope(shield=True):
+ # Authoritative fallback cleanup for lock/cancel state. Middleware
+ # teardown can be skipped on some client-abort paths.
+ end_turn(str(chat_id))
+
# Release premium reservation if not finalized
- if _premium_request_id and _premium_reserved > 0 and user_id:
+ if _premium_request_id and _premium_reserved_micros > 0 and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
@@ -2662,9 +3715,9 @@ async def stream_new_chat(
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=UUID(user_id),
- reserved_tokens=_premium_reserved,
+ reserved_micros=_premium_reserved_micros,
)
- _premium_reserved = 0
+ _premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to release premium quota for user %s", user_id
@@ -2688,6 +3741,81 @@ async def stream_new_chat(
with contextlib.suppress(Exception):
await session.close()
+ # Server-side assistant-message + token_usage finalization.
+ # Runs after the main session has been closed (uses its own
+ # shielded session) so we don't fight the same DB connection.
+ # Idempotent against the legacy frontend appendMessage:
+ # * the assistant row was already INSERTed by
+ # ``persist_assistant_shell`` above, so this just UPDATEs
+ # it with the rich ContentPart[] from the builder.
+ # * token_usage uses INSERT ... ON CONFLICT DO NOTHING
+ # against migration 142's partial unique index, so a
+ # racing append_message recovery branch can never
+ # double-write.
+ # ``mark_interrupted`` closes any open text/reasoning blocks
+ # and flips running tool-calls (no result) to state=aborted
+ # so the persisted JSONB reflects a coherent end-state even
+ # on client disconnect.
+ # Never raises (best-effort, logs only).
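+ # Sketch of the idempotent token_usage write (assumed shape;
+ # ``TokenUsage`` is illustrative, and the real statement lives
+ # in ``finalize_assistant_turn``):
+ #   from sqlalchemy.dialects.postgresql import insert as pg_insert
+ #   stmt = pg_insert(TokenUsage).values(**usage_row)
+ #   await session.execute(stmt.on_conflict_do_nothing())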
+ if (
+ stream_result
+ and stream_result.turn_id
+ and stream_result.assistant_message_id
+ ):
+ from app.tasks.chat.persistence import finalize_assistant_turn
+
+ builder_stats: dict[str, int] | None = None
+ if stream_result.content_builder is not None:
+ stream_result.content_builder.mark_interrupted()
+ # Snapshot stats BEFORE deepcopy in ``snapshot()`` so
+ # the perf log records the actual finalised payload
+ # (post-mark_interrupted), not the live-mutating
+ # builder state.
+ builder_stats = stream_result.content_builder.stats()
+ content_payload = stream_result.content_builder.snapshot()
+ else:
+ # Defensive fallback — we always set the builder
+ # alongside ``assistant_message_id`` above, so this
+ # branch only fires if a future refactor ever
+ # decouples them. Persist whatever accumulated
+ # text we captured so the row at least renders.
+ content_payload = [
+ {
+ "type": "text",
+ "text": stream_result.accumulated_text or "",
+ }
+ ]
+
+ if builder_stats is not None:
+ _perf_log.info(
+ "[stream_new_chat] finalize_payload chat_id=%s "
+ "message_id=%s parts=%d bytes=%d text=%d "
+ "reasoning=%d tool_calls=%d "
+ "tool_calls_completed=%d tool_calls_aborted=%d "
+ "thinking_step_parts=%d step_separators=%d",
+ chat_id,
+ stream_result.assistant_message_id,
+ builder_stats["parts"],
+ builder_stats["bytes"],
+ builder_stats["text"],
+ builder_stats["reasoning"],
+ builder_stats["tool_calls"],
+ builder_stats["tool_calls_completed"],
+ builder_stats["tool_calls_aborted"],
+ builder_stats["thinking_step_parts"],
+ builder_stats["step_separators"],
+ )
+
+ await finalize_assistant_turn(
+ message_id=stream_result.assistant_message_id,
+ chat_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ turn_id=stream_result.turn_id,
+ content=content_payload,
+ accumulator=accumulator,
+ )
+
# Persist any sandbox-produced files to local storage so they
# remain downloadable after the Daytona sandbox auto-deletes.
if stream_result and stream_result.sandbox_files:
@@ -2707,11 +3835,11 @@ async def stream_new_chat(
# Skip on ``BusyError`` (caller never acquired the lock).
if not _busy_error_raised:
with contextlib.suppress(Exception):
- if _release_busy_lock(str(chat_id)):
- _perf_log.info(
- "[stream_new_chat] released stale busy lock (chat_id=%s)",
- chat_id,
- )
+ end_turn(str(chat_id))
+ _perf_log.info(
+ "[stream_new_chat] end_turn cleanup (chat_id=%s)",
+ chat_id,
+ )
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
@@ -2765,83 +3893,218 @@ async def stream_resume_chat(
# Skip the finally release on ``BusyError`` (caller never acquired the lock).
_busy_error_raised = False
+ _emit_stream_error = partial(
+ _emit_stream_terminal_error,
+ streaming_service=streaming_service,
+ flow="resume",
+ request_id=request_id,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+
session = async_session_maker()
try:
if user_id:
await set_ai_responding(session, chat_id, UUID(user_id))
agent_config: AgentConfig | None = None
- _t0 = time.perf_counter()
- if llm_config_id >= 0:
- agent_config = await load_agent_config(
- session=session,
- config_id=llm_config_id,
- search_space_id=search_space_id,
+ requested_llm_config_id = llm_config_id
+
+ async def _load_llm_bundle(
+ config_id: int,
+ ) -> tuple[Any, AgentConfig | None, str | None]:
+ if config_id >= 0:
+ loaded_agent_config = await load_agent_config(
+ session=session,
+ config_id=config_id,
+ search_space_id=search_space_id,
+ )
+ if not loaded_agent_config:
+ return (
+ None,
+ None,
+ f"Failed to load NewLLMConfig with id {config_id}",
+ )
+ return (
+ create_chat_litellm_from_agent_config(loaded_agent_config),
+ loaded_agent_config,
+ None,
+ )
+
+ loaded_llm_config = load_global_llm_config_by_id(config_id)
+ if not loaded_llm_config:
+ return None, None, f"Failed to load LLM config with id {config_id}"
+ return (
+ create_chat_litellm_from_config(loaded_llm_config),
+ AgentConfig.from_yaml_config(loaded_llm_config),
+ None,
)
- if not agent_config:
- yield streaming_service.format_error(
- f"Failed to load NewLLMConfig with id {llm_config_id}"
+
+ _t0 = time.perf_counter()
+ try:
+ llm_config_id = (
+ await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ selected_llm_config_id=llm_config_id,
)
- yield streaming_service.format_done()
- return
- llm = create_chat_litellm_from_agent_config(agent_config)
- else:
- llm_config = load_global_llm_config_by_id(llm_config_id)
- if not llm_config:
- yield streaming_service.format_error(
- f"Failed to load LLM config with id {llm_config_id}"
- )
- yield streaming_service.format_done()
- return
- llm = create_chat_litellm_from_config(llm_config)
- agent_config = AgentConfig.from_yaml_config(llm_config)
+ ).resolved_llm_config_id
+ except ValueError as pin_error:
+ yield _emit_stream_error(
+ message=str(pin_error),
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
+ yield streaming_service.format_done()
+ return
+
+ llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
+ if llm_load_error:
+ yield _emit_stream_error(
+ message=llm_load_error,
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
+ yield streaming_service.format_done()
+ return
_perf_log.info(
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
)
- # Premium quota reservation (same logic as stream_new_chat)
- _resume_premium_reserved = 0
+ # Premium credit reservation (same logic as stream_new_chat).
+ _resume_premium_reserved_micros = 0
_resume_premium_request_id: str | None = None
_resume_needs_premium = (
- agent_config is not None
- and user_id
- and (agent_config.is_premium or agent_config.is_auto_mode)
+ agent_config is not None and user_id and agent_config.is_premium
)
if _resume_needs_premium:
import uuid as _uuid
- from app.config import config as _app_config
- from app.services.token_quota_service import TokenQuotaService
+ from app.services.token_quota_service import (
+ TokenQuotaService,
+ estimate_call_reserve_micros,
+ )
_resume_premium_request_id = _uuid.uuid4().hex[:16]
- reserve_amount = min(
- agent_config.quota_reserve_tokens
- or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
- _app_config.QUOTA_MAX_RESERVE_PER_CALL,
+ _resume_litellm_params = agent_config.litellm_params or {}
+ _resume_base_model = (
+ _resume_litellm_params.get("base_model")
+ or agent_config.model_name
+ or ""
+ )
+ reserve_amount_micros = estimate_call_reserve_micros(
+ base_model=_resume_base_model,
+ quota_reserve_tokens=agent_config.quota_reserve_tokens,
)
async with shielded_async_session() as quota_session:
quota_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_resume_premium_request_id,
- reserve_tokens=reserve_amount,
+ reserve_micros=reserve_amount_micros,
)
- _resume_premium_reserved = reserve_amount
+ _resume_premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed:
- if agent_config.is_premium:
- yield streaming_service.format_error(
- "Premium token quota exceeded. Please purchase more tokens to continue using premium models."
+ if requested_llm_config_id == 0:
+ try:
+ llm_config_id = (
+ await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ selected_llm_config_id=0,
+ force_repin_free=True,
+ )
+ ).resolved_llm_config_id
+ except ValueError as pin_error:
+ yield _emit_stream_error(
+ message=str(pin_error),
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
+ yield streaming_service.format_done()
+ return
+
+ llm, agent_config, llm_load_error = await _load_llm_bundle(
+ llm_config_id
+ )
+ if llm_load_error:
+ yield _emit_stream_error(
+ message=llm_load_error,
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
+ yield streaming_service.format_done()
+ return
+ _resume_premium_request_id = None
+ _resume_premium_reserved_micros = 0
+ _log_chat_stream_error(
+ flow="resume",
+ error_kind="premium_quota_exhausted",
+ error_code="PREMIUM_QUOTA_EXHAUSTED",
+ severity="info",
+ is_expected=True,
+ request_id=request_id,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ message=(
+ "Premium quota exhausted on pinned model; auto-fallback switched to a free model"
+ ),
+ extra={
+ "fallback_config_id": llm_config_id,
+ "auto_fallback": True,
+ },
+ )
+ else:
+ yield _emit_stream_error(
+ message=(
+ "Buy more tokens to continue with this model, or switch to a free model"
+ ),
+ error_kind="premium_quota_exhausted",
+ error_code="PREMIUM_QUOTA_EXHAUSTED",
+ severity="info",
+ is_expected=True,
+ extra={
+ "resolved_config_id": llm_config_id,
+ "auto_fallback": False,
+ },
)
yield streaming_service.format_done()
return
- _resume_premium_request_id = None
- _resume_premium_reserved = 0
if not llm:
- yield streaming_service.format_error("Failed to create LLM instance")
+ yield _emit_stream_error(
+ message="Failed to create LLM instance",
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
yield streaming_service.format_done()
return
+ # Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
+ # one cheap probe before the agent is rebuilt so a 429'd pin gets
+ # repinned without burning planner/classifier/title calls first.
+ # See ``stream_new_chat`` for the full rationale on the speculative
+ # parallel build pattern below.
+ preflight_needed = (
+ requested_llm_config_id == 0
+ and llm_config_id < 0
+ and not is_recently_healthy(llm_config_id)
+ )
+ preflight_task: asyncio.Task[None] | None = None
+ _t_preflight = 0.0
+ if preflight_needed:
+ _t_preflight = time.perf_counter()
+ preflight_task = asyncio.create_task(
+ _preflight_llm(llm),
+ name=f"auto_pin_preflight_resume:{llm_config_id}",
+ )
+
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
@@ -2866,8 +4129,13 @@ async def stream_resume_chat(
from app.config import config as _app_config
_t0 = time.perf_counter()
- if _app_config.MULTI_AGENT_CHAT_ENABLED:
- agent = await create_registry_deep_agent(
+ agent_factory = (
+ create_registry_deep_agent
+ if _app_config.MULTI_AGENT_CHAT_ENABLED
+ else create_surfsense_deep_agent
+ )
+ agent_build_task = asyncio.create_task(
+ agent_factory(
llm=llm,
search_space_id=search_space_id,
db_session=session,
@@ -2880,22 +4148,99 @@ async def stream_resume_chat(
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
- )
- else:
- agent = await create_surfsense_deep_agent(
- llm=llm,
- search_space_id=search_space_id,
- db_session=session,
- connector_service=connector_service,
- checkpointer=checkpointer,
- user_id=user_id,
- thread_id=chat_id,
- agent_config=agent_config,
- firecrawl_api_key=firecrawl_api_key,
- thread_visibility=visibility,
- filesystem_selection=filesystem_selection,
- disabled_tools=disabled_tools,
- )
+ ),
+ name="agent_build:stream_resume",
+ )
+
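+ # Pattern note: the preflight ping and the speculative agent
+ # build run concurrently; on a ping failure the build task is
+ # settled (awaited or cancelled) via
+ # ``_settle_speculative_agent_build`` before any repin work so
+ # its in-flight DB session activity cannot interleave with the
+ # rebuild below.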
+ agent: Any = None
+ if preflight_task is not None:
+ try:
+ await preflight_task
+ mark_healthy(llm_config_id)
+ _perf_log.info(
+ "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
+ llm_config_id,
+ time.perf_counter() - _t_preflight,
+ )
+ except Exception as preflight_exc:
+ # Same session-safety rationale as ``stream_new_chat``.
+ await _settle_speculative_agent_build(agent_build_task)
+ if not _is_provider_rate_limited(preflight_exc):
+ raise
+ previous_config_id = llm_config_id
+ mark_runtime_cooldown(
+ previous_config_id, reason="preflight_rate_limited"
+ )
+ try:
+ llm_config_id = (
+ await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ selected_llm_config_id=0,
+ exclude_config_ids={previous_config_id},
+ )
+ ).resolved_llm_config_id
+ except ValueError as pin_error:
+ yield _emit_stream_error(
+ message=str(pin_error),
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
+ yield streaming_service.format_done()
+ return
+
+ llm, agent_config, llm_load_error = await _load_llm_bundle(
+ llm_config_id
+ )
+ if llm_load_error or not llm:
+ yield _emit_stream_error(
+ message=llm_load_error or "Failed to create LLM instance",
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ )
+ yield streaming_service.format_done()
+ return
+ mark_healthy(llm_config_id)
+ _log_chat_stream_error(
+ flow="resume",
+ error_kind="rate_limited",
+ error_code="RATE_LIMITED",
+ severity="info",
+ is_expected=True,
+ request_id=request_id,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ message=(
+ "Auto-pinned model failed preflight; switched to another "
+ "eligible model and continuing."
+ ),
+ extra={
+ "auto_runtime_recover": True,
+ "preflight": True,
+ "previous_config_id": previous_config_id,
+ "fallback_config_id": llm_config_id,
+ },
+ )
+ agent = await agent_factory(
+ llm=llm,
+ search_space_id=search_space_id,
+ db_session=session,
+ connector_service=connector_service,
+ checkpointer=checkpointer,
+ user_id=user_id,
+ thread_id=chat_id,
+ agent_config=agent_config,
+ firecrawl_api_key=firecrawl_api_key,
+ thread_visibility=visibility,
+ filesystem_selection=filesystem_selection,
+ disabled_tools=disabled_tools,
+ )
+
+ if agent is None:
+ agent = await agent_build_task
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
)
@@ -2937,34 +4282,174 @@ async def stream_resume_chat(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)
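+ # ``turn-status`` lifecycle on this stream: "busy" here at the
+ # start, "cancelling" if the error classifier reports a cancel
+ # in flight, and "idle" on every terminal path (success or
+ # error) before finish/done.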
+ yield streaming_service.format_data("turn-status", {"status": "busy"})
+
+ # Pre-write a fresh assistant row for this resume turn. The
+ # original (interrupted) ``stream_new_chat`` invocation already
+ # persisted its own assistant row anchored to a different
+ # ``turn_id``; resume allocates a new ``turn_id`` (above) so we
+ # need a separate row keyed on the same ``(thread_id, turn_id,
+ # ASSISTANT)`` invariant. Idempotent against migration 141's
+ # partial unique index — recovers existing id on retry.
+ from app.tasks.chat.content_builder import AssistantContentBuilder
+ from app.tasks.chat.persistence import persist_assistant_shell
+
+ assistant_message_id = await persist_assistant_shell(
+ chat_id=chat_id,
+ user_id=user_id,
+ turn_id=stream_result.turn_id,
+ )
+ if assistant_message_id is None:
+ yield _emit_stream_error(
+ message=(
+ "We couldn't initialize the assistant message. Please try again."
+ ),
+ error_kind="server_error",
+ error_code="MESSAGE_PERSIST_FAILED",
+ )
+ yield streaming_service.format_data("turn-status", {"status": "idle"})
+ yield streaming_service.format_finish_step()
+ yield streaming_service.format_finish()
+ yield streaming_service.format_done()
+ return
+
+ # Emit canonical assistant message id BEFORE any LLM streaming
+ # so the FE can rename ``pendingInterrupt.assistantMsgId`` to
+ # ``msg-{assistant_message_id}`` immediately. Resume does NOT
+ # emit ``data-user-message-id`` because the user row is from
+ # the original interrupted turn (different ``turn_id``) and is
+ # never re-persisted here. See B5 in the
+ # ``sse-based_message_id_handshake`` plan.
+ yield streaming_service.format_data(
+ "assistant-message-id",
+ {"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
+ )
+
+ stream_result.assistant_message_id = assistant_message_id
+ stream_result.content_builder = AssistantContentBuilder()
+
+ # Resume path doesn't carry new ``mentioned_document_ids`` —
+ # those are seeded in the original turn. We still pass a
+ # context so future middleware extensions (Phase 2) can rely on
+ # ``runtime.context`` always being populated.
+ runtime_context = SurfSenseContextSchema(
+ search_space_id=search_space_id,
+ request_id=request_id,
+ turn_id=stream_result.turn_id,
+ )
_t_stream_start = time.perf_counter()
_first_event_logged = False
- async for sse in _stream_agent_events(
- agent=agent,
- config=config,
- input_data=Command(resume={"decisions": decisions}),
- streaming_service=streaming_service,
- result=stream_result,
- step_prefix="thinking-resume",
- fallback_commit_search_space_id=search_space_id,
- fallback_commit_created_by_id=user_id,
- fallback_commit_filesystem_mode=(
- filesystem_selection.mode
- if filesystem_selection
- else FilesystemMode.CLOUD
- ),
- fallback_commit_thread_id=chat_id,
- ):
- if not _first_event_logged:
- _perf_log.info(
- "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
- time.perf_counter() - _t_stream_start,
- time.perf_counter() - _t_total,
- chat_id,
+ runtime_rate_limit_recovered = False
+ while True:
+ try:
+ async for sse in _stream_agent_events(
+ agent=agent,
+ config=config,
+ input_data=Command(resume={"decisions": decisions}),
+ streaming_service=streaming_service,
+ result=stream_result,
+ step_prefix="thinking-resume",
+ fallback_commit_search_space_id=search_space_id,
+ fallback_commit_created_by_id=user_id,
+ fallback_commit_filesystem_mode=(
+ filesystem_selection.mode
+ if filesystem_selection
+ else FilesystemMode.CLOUD
+ ),
+ fallback_commit_thread_id=chat_id,
+ runtime_context=runtime_context,
+ content_builder=stream_result.content_builder,
+ ):
+ if not _first_event_logged:
+ _perf_log.info(
+ "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
+ time.perf_counter() - _t_stream_start,
+ time.perf_counter() - _t_total,
+ chat_id,
+ )
+ _first_event_logged = True
+ yield sse
+ break
+ except Exception as stream_exc:
+ can_runtime_recover = (
+ not runtime_rate_limit_recovered
+ and requested_llm_config_id == 0
+ and llm_config_id < 0
+ and not _first_event_logged
+ and _is_provider_rate_limited(stream_exc)
)
- _first_event_logged = True
- yield sse
+ if not can_runtime_recover:
+ raise
+
+ runtime_rate_limit_recovered = True
+ previous_config_id = llm_config_id
+ # Ensure the same-request recovery retry does not trip the
+ # BusyMutex lock retained by the failed attempt.
+ end_turn(str(chat_id))
+ mark_runtime_cooldown(
+ previous_config_id,
+ reason="provider_rate_limited",
+ )
+ llm_config_id = (
+ await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ selected_llm_config_id=0,
+ exclude_config_ids={previous_config_id},
+ )
+ ).resolved_llm_config_id
+
+ llm, agent_config, llm_load_error = await _load_llm_bundle(
+ llm_config_id
+ )
+ if llm_load_error:
+ raise stream_exc
+
+ _t0 = time.perf_counter()
+ agent = await create_surfsense_deep_agent(
+ llm=llm,
+ search_space_id=search_space_id,
+ db_session=session,
+ connector_service=connector_service,
+ checkpointer=checkpointer,
+ user_id=user_id,
+ thread_id=chat_id,
+ agent_config=agent_config,
+ firecrawl_api_key=firecrawl_api_key,
+ thread_visibility=visibility,
+ filesystem_selection=filesystem_selection,
+ disabled_tools=disabled_tools,
+ )
+ _perf_log.info(
+ "[stream_resume] Runtime rate-limit recovery repinned "
+ "config_id=%s -> %s and rebuilt agent in %.3fs",
+ previous_config_id,
+ llm_config_id,
+ time.perf_counter() - _t0,
+ )
+ _log_chat_stream_error(
+ flow="resume",
+ error_kind="rate_limited",
+ error_code="RATE_LIMITED",
+ severity="info",
+ is_expected=True,
+ request_id=request_id,
+ thread_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ message=(
+ "Auto-pinned model hit runtime rate limit; switched to "
+ "another eligible model and retried."
+ ),
+ extra={
+ "auto_runtime_recover": True,
+ "previous_config_id": previous_config_id,
+ "fallback_config_id": llm_config_id,
+ },
+ )
+ continue
_perf_log.info(
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
time.perf_counter() - _t_stream_start,
@@ -2973,9 +4458,10 @@ async def stream_resume_chat(
if stream_result.is_interrupted:
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
+ "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -2986,6 +4472,7 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@@ -2995,28 +4482,23 @@ async def stream_resume_chat(
yield streaming_service.format_done()
return
- # Finalize premium quota for resume path
+ # Finalize premium credit debit for resume path with the actual
+ # provider cost reported by LiteLLM (sum of cost across all
+ # calls in the turn).
if _resume_premium_request_id and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
- if agent_config and agent_config.is_auto_mode:
- from app.services.llm_router_service import LLMRouterService
-
- actual_premium_tokens = LLMRouterService.compute_premium_tokens(
- accumulator.calls
- )
- else:
- actual_premium_tokens = accumulator.grand_total
-
async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_finalize(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_resume_premium_request_id,
- actual_tokens=actual_premium_tokens,
- reserved_tokens=_resume_premium_reserved,
+ actual_micros=accumulator.total_cost_micros,
+ reserved_micros=_resume_premium_reserved_micros,
)
+ _resume_premium_request_id = None
+ _resume_premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s (resume)",
@@ -3026,9 +4508,10 @@ async def stream_resume_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
+ "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -3039,17 +4522,12 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
- yield streaming_service.format_finish_step()
- yield streaming_service.format_finish()
- yield streaming_service.format_done()
-
- except BusyError as e:
- _busy_error_raised = True
- yield streaming_service.format_error(str(e))
+ yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
@@ -3057,18 +4535,55 @@ async def stream_resume_chat(
except Exception as e:
import traceback
+ # ``BusyError`` fires before the agent acquires the lock; the
+ # cleanup path must skip lock release to avoid freeing the
+ # in-flight caller's lock. Classification is handled below.
+ if isinstance(e, BusyError):
+ _busy_error_raised = True
+
+ (
+ error_kind,
+ error_code,
+ severity,
+ is_expected,
+ user_message,
+ error_extra,
+ ) = _classify_stream_exception(e, flow_label="resume")
error_message = f"Error during resume: {e!s}"
print(f"[stream_resume_chat] {error_message}")
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
- yield streaming_service.format_error(error_message)
+ if error_code == "TURN_CANCELLING":
+ status_payload: dict[str, Any] = {"status": "cancelling"}
+ if error_extra:
+ status_payload.update(error_extra)
+ yield streaming_service.format_data("turn-status", status_payload)
+ else:
+ yield streaming_service.format_data("turn-status", {"status": "busy"})
+ yield _emit_stream_error(
+ message=user_message,
+ error_kind=error_kind,
+ error_code=error_code,
+ severity=severity,
+ is_expected=is_expected,
+ extra=error_extra,
+ )
+ yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
finally:
with anyio.CancelScope(shield=True):
+ # Authoritative fallback cleanup for lock/cancel state. Middleware
+ # teardown can be skipped on some client-abort paths.
+ end_turn(str(chat_id))
+
# Release premium reservation if not finalized
- if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
+ if (
+ _resume_premium_request_id
+ and _resume_premium_reserved_micros > 0
+ and user_id
+ ):
try:
from app.services.token_quota_service import TokenQuotaService
@@ -3076,9 +4591,9 @@ async def stream_resume_chat(
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=UUID(user_id),
- reserved_tokens=_resume_premium_reserved,
+ reserved_micros=_resume_premium_reserved_micros,
)
- _resume_premium_reserved = 0
+ _resume_premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to release premium quota for user %s (resume)", user_id
@@ -3102,15 +4617,73 @@ async def stream_resume_chat(
with contextlib.suppress(Exception):
await session.close()
+ # Server-side assistant-message + token_usage finalization for
+ # the resume flow. The original user message was persisted by
+ # the original (interrupted) ``stream_new_chat`` invocation;
+ # the resume's own ``persist_assistant_shell`` write lives at
+ # the new ``turn_id`` above. This finalize updates that row
+ # with the rich ContentPart[] from the builder and writes
+ # token_usage idempotently via migration 142's partial
+ # unique index. Best-effort, never raises.
+ if (
+ stream_result
+ and stream_result.turn_id
+ and stream_result.assistant_message_id
+ ):
+ from app.tasks.chat.persistence import finalize_assistant_turn
+
+ builder_stats: dict[str, int] | None = None
+ if stream_result.content_builder is not None:
+ stream_result.content_builder.mark_interrupted()
+ builder_stats = stream_result.content_builder.stats()
+ content_payload = stream_result.content_builder.snapshot()
+ else:
+ content_payload = [
+ {
+ "type": "text",
+ "text": stream_result.accumulated_text or "",
+ }
+ ]
+
+ if builder_stats is not None:
+ _perf_log.info(
+ "[stream_resume] finalize_payload chat_id=%s "
+ "message_id=%s parts=%d bytes=%d text=%d "
+ "reasoning=%d tool_calls=%d "
+ "tool_calls_completed=%d tool_calls_aborted=%d "
+ "thinking_step_parts=%d step_separators=%d",
+ chat_id,
+ stream_result.assistant_message_id,
+ builder_stats["parts"],
+ builder_stats["bytes"],
+ builder_stats["text"],
+ builder_stats["reasoning"],
+ builder_stats["tool_calls"],
+ builder_stats["tool_calls_completed"],
+ builder_stats["tool_calls_aborted"],
+ builder_stats["thinking_step_parts"],
+ builder_stats["step_separators"],
+ )
+
+ await finalize_assistant_turn(
+ message_id=stream_result.assistant_message_id,
+ chat_id=chat_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ turn_id=stream_result.turn_id,
+ content=content_payload,
+ accumulator=accumulator,
+ )
+
# Release the lock from the original interrupted turn or any
# re-interrupt/bailout. Skip on ``BusyError`` (lock not held here).
if not _busy_error_raised:
with contextlib.suppress(Exception):
- if _release_busy_lock(str(chat_id)):
- _perf_log.info(
- "[stream_resume] released stale busy lock (chat_id=%s)",
- chat_id,
- )
+ end_turn(str(chat_id))
+ _perf_log.info(
+ "[stream_resume] end_turn cleanup (chat_id=%s)",
+ chat_id,
+ )
agent = llm = connector_service = None
stream_result = None
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
index 6912ffe5a..3c9f27303 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
@@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
+from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
-from app.utils.google_credentials import (
- COMPOSIO_GOOGLE_CONNECTOR_TYPES,
- build_composio_credentials,
-)
+from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
from .base import (
check_duplicate_document_by_hash,
@@ -44,6 +42,10 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30
+def _format_calendar_event_to_markdown(event: dict) -> str:
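+ # Call the formatter unbound with ``None`` as the receiver; this
+ # relies on ``format_event_to_markdown`` never touching ``self``,
+ # letting the Composio path reuse the formatting logic without
+ # constructing a credentialed client.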
+ return GoogleCalendarConnector.format_event_to_markdown(None, event)
+
+
def _build_connector_doc(
event: dict,
event_markdown: str,
@@ -150,7 +152,14 @@ async def index_google_calendar_events(
)
return 0, 0, f"Connector with ID {connector_id} not found"
- # ── Credential building ───────────────────────────────────────
+ is_composio_connector = (
+ connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
+ )
+ calendar_client = None
+ composio_service = None
+ connected_account_id = None
+
+ # ── Credential/client building ────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
@@ -161,7 +170,7 @@ async def index_google_calendar_events(
{"error_type": "MissingComposioAccount"},
)
return 0, 0, "Composio connected_account_id not found"
- credentials = build_composio_credentials(connected_account_id)
+ composio_service = ComposioService()
else:
config_data = connector.config
@@ -229,12 +238,13 @@ async def index_google_calendar_events(
{"stage": "client_initialization"},
)
- calendar_client = GoogleCalendarConnector(
- credentials=credentials,
- session=session,
- user_id=user_id,
- connector_id=connector_id,
- )
+ if not is_composio_connector:
+ calendar_client = GoogleCalendarConnector(
+ credentials=credentials,
+ session=session,
+ user_id=user_id,
+ connector_id=connector_id,
+ )
# Handle 'undefined' string from frontend (treat as None)
if start_date == "undefined" or start_date == "":
@@ -300,9 +310,26 @@ async def index_google_calendar_events(
)
try:
- events, error = await calendar_client.get_all_primary_calendar_events(
- start_date=start_date_str, end_date=end_date_str
- )
+ if is_composio_connector:
+ start_dt = parse_date_flexible(start_date_str).replace(
+ hour=0, minute=0, second=0, microsecond=0
+ )
+ end_dt = parse_date_flexible(end_date_str).replace(
+ hour=23, minute=59, second=59, microsecond=0
+ )
+ events, error = await composio_service.get_calendar_events(
+ connected_account_id=connected_account_id,
+ entity_id=f"surfsense_{user_id}",
+ time_min=start_dt.isoformat(),
+ time_max=end_dt.isoformat(),
+ max_results=250,
+ )
+ if not events and not error:
+ error = "No events found in the specified date range."
+ else:
+ events, error = await calendar_client.get_all_primary_calendar_events(
+ start_date=start_date_str, end_date=end_date_str
+ )
if error:
if "No events found" in error:
@@ -381,7 +408,7 @@ async def index_google_calendar_events(
documents_skipped += 1
continue
- event_markdown = calendar_client.format_event_to_markdown(event)
+ event_markdown = _format_calendar_event_to_markdown(event)
if not event_markdown.strip():
logger.warning(f"Skipping event with no content: {event_summary}")
documents_skipped += 1
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
index 21cdbd29f..686f13d9e 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
@@ -9,6 +9,8 @@ import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
+from pathlib import Path
+from typing import Any
from sqlalchemy import String, cast, select
from sqlalchemy.exc import SQLAlchemyError
@@ -37,6 +39,7 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
+from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.page_limit_service import PageLimitService
from app.services.task_logging_service import TaskLoggingService
@@ -45,10 +48,7 @@ from app.tasks.connector_indexers.base import (
get_connector_by_id,
update_connector_last_indexed,
)
-from app.utils.google_credentials import (
- COMPOSIO_GOOGLE_CONNECTOR_TYPES,
- build_composio_credentials,
-)
+from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
ACCEPTED_DRIVE_CONNECTOR_TYPES = {
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
@@ -61,6 +61,209 @@ HEARTBEAT_INTERVAL_SECONDS = 30
logger = logging.getLogger(__name__)
+class ComposioDriveClient:
+ """Google Drive client facade backed by Composio tool execution.
+
+ Tool execution happens through Composio's API on the managed OAuth
+ connection, so raw OAuth tokens never have to be extracted from the
+ connected-account state.
+ """
+
+ def __init__(
+ self,
+ session: AsyncSession,
+ connector_id: int,
+ connected_account_id: str,
+ entity_id: str,
+ ):
+ self.session = session
+ self.connector_id = connector_id
+ self.connected_account_id = connected_account_id
+ self.entity_id = entity_id
+ self.composio = ComposioService()
+
+ async def list_files(
+ self,
+ query: str = "",
+ fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)",
+ page_size: int = 100,
+ page_token: str | None = None,
+ ) -> tuple[list[dict[str, Any]], str | None, str | None]:
+ params: dict[str, Any] = {
+ "page_size": min(page_size, 100),
+ "fields": fields,
+ }
+ if query:
+ params["q"] = query
+ if page_token:
+ params["page_token"] = page_token
+
+ result = await self.composio.execute_tool(
+ connected_account_id=self.connected_account_id,
+ tool_name="GOOGLEDRIVE_LIST_FILES",
+ params=params,
+ entity_id=self.entity_id,
+ )
+ if not result.get("success"):
+ return [], None, result.get("error", "Unknown error")
+
+ data = result.get("data", {})
+ files = []
+ next_token = None
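+ # Composio tool responses sometimes wrap the payload twice
+ # ({"data": {"data": {...}}}); unwrap defensively and accept
+ # either camelCase or snake_case page tokens.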
+ if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ if isinstance(inner_data, dict):
+ files = inner_data.get("files", [])
+ next_token = inner_data.get("nextPageToken") or inner_data.get(
+ "next_page_token"
+ )
+ elif isinstance(data, list):
+ files = data
+
+ return files, next_token, None
+
+ async def get_file_metadata(
+ self, file_id: str, fields: str = "*"
+ ) -> tuple[dict[str, Any] | None, str | None]:
+ result = await self.composio.execute_tool(
+ connected_account_id=self.connected_account_id,
+ tool_name="GOOGLEDRIVE_GET_FILE_METADATA",
+ params={"file_id": file_id, "fields": fields},
+ entity_id=self.entity_id,
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown error")
+
+ data = result.get("data", {})
+ if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ if isinstance(inner_data, dict):
+ return inner_data, None
+
+ return None, "Could not extract metadata from Composio response"
+
+ async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
+ return await self._download_file_content(file_id)
+
+ async def download_file_to_disk(
+ self,
+ file_id: str,
+ dest_path: str,
+ chunksize: int = 5 * 1024 * 1024,
+ ) -> str | None:
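+ # ``chunksize`` exists only for signature parity with
+ # ``GoogleDriveClient``; Composio returns whole files, so
+ # chunked download is a no-op here.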
+ del chunksize
+ content, error = await self.download_file(file_id)
+ if error:
+ return error
+ if content is None:
+ return "No content returned from Composio"
+ Path(dest_path).write_bytes(content)
+ return None
+
+ async def export_google_file(
+ self, file_id: str, mime_type: str
+ ) -> tuple[bytes | None, str | None]:
+ return await self._download_file_content(file_id, mime_type=mime_type)
+
+ async def _download_file_content(
+ self, file_id: str, mime_type: str | None = None
+ ) -> tuple[bytes | None, str | None]:
+ params: dict[str, Any] = {"file_id": file_id}
+ if mime_type:
+ params["mime_type"] = mime_type
+
+ result = await self.composio.execute_tool(
+ connected_account_id=self.connected_account_id,
+ tool_name="GOOGLEDRIVE_DOWNLOAD_FILE",
+ params=params,
+ entity_id=self.entity_id,
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown error")
+
+ return self._read_download_result(result.get("data"))
+
+ def _read_download_result(self, data: Any) -> tuple[bytes | None, str | None]:
+ if isinstance(data, bytes):
+ return data, None
+
+ file_path: str | None = None
+ if isinstance(data, str):
+ file_path = data
+ elif isinstance(data, dict):
+ inner_data = data.get("data", data)
+ if isinstance(inner_data, dict):
+ for key in ("file_path", "downloaded_file_content", "path", "uri"):
+ value = inner_data.get(key)
+ if isinstance(value, str):
+ file_path = value
+ break
+ if isinstance(value, dict):
+ nested = (
+ value.get("file_path")
+ or value.get("downloaded_file_content")
+ or value.get("path")
+ or value.get("uri")
+ or value.get("s3url")
+ )
+ if isinstance(nested, str):
+ file_path = nested
+ break
+
+ if not file_path:
+ return None, "No file path/content returned from Composio"
+
+ if file_path.startswith(("http://", "https://")):
+ try:
+ import urllib.request
+
+ with urllib.request.urlopen(file_path, timeout=60) as response:
+ return response.read(), None
+ except Exception as e:
+ return None, f"Failed to download Composio file URL: {e!s}"
+
+ path_obj = Path(file_path)
+ if path_obj.is_absolute() or ".composio" in str(path_obj):
+ if not path_obj.exists():
+ return None, f"File not found at path: {file_path}"
+ return path_obj.read_bytes(), None
+
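+ # Last resort: treat the value as inline content. ``b64decode``
+ # raises on bad padding, in which case we fall back to the raw
+ # UTF-8 bytes. This is a heuristic; some non-base64 strings
+ # still decode "successfully" into garbage, so treat the result
+ # as best-effort.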
+ try:
+ import base64
+
+ return base64.b64decode(file_path), None
+ except Exception:
+ return file_path.encode("utf-8"), None
+
+
+def _build_drive_client_for_connector(
+ session: AsyncSession,
+ connector_id: int,
+ connector: object,
+ user_id: str,
+) -> tuple[GoogleDriveClient | ComposioDriveClient | None, str | None]:
+ if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
+ connected_account_id = connector.config.get("composio_connected_account_id")
+ if not connected_account_id:
+ return None, (
+ f"Composio connected_account_id not found for connector {connector_id}"
+ )
+ return (
+ ComposioDriveClient(
+ session,
+ connector_id,
+ connected_account_id,
+ entity_id=f"surfsense_{user_id}",
+ ),
+ None,
+ )
+
+ token_encrypted = connector.config.get("_token_encrypted", False)
+ if token_encrypted and not config.SECRET_KEY:
+ return None, "SECRET_KEY not configured but credentials are marked as encrypted"
+
+ return GoogleDriveClient(session, connector_id), None
+
+
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -927,34 +1130,17 @@ async def index_google_drive_files(
{"stage": "client_initialization"},
)
- pre_built_credentials = None
- if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
- connected_account_id = connector.config.get("composio_connected_account_id")
- if not connected_account_id:
- error_msg = f"Composio connected_account_id not found for connector {connector_id}"
- await task_logger.log_task_failure(
- log_entry,
- error_msg,
- "Missing Composio account",
- {"error_type": "MissingComposioAccount"},
- )
- return 0, 0, error_msg, 0
- pre_built_credentials = build_composio_credentials(connected_account_id)
- else:
- token_encrypted = connector.config.get("_token_encrypted", False)
- if token_encrypted and not config.SECRET_KEY:
- await task_logger.log_task_failure(
- log_entry,
- "SECRET_KEY not configured but credentials are encrypted",
- "Missing SECRET_KEY",
- {"error_type": "MissingSecretKey"},
- )
- return (
- 0,
- 0,
- "SECRET_KEY not configured but credentials are marked as encrypted",
- 0,
- )
+ drive_client, client_error = _build_drive_client_for_connector(
+ session, connector_id, connector, user_id
+ )
+ if client_error or not drive_client:
+ await task_logger.log_task_failure(
+ log_entry,
+ client_error or "Failed to initialize Google Drive client",
+ "Missing connector credentials",
+ {"error_type": "ClientInitializationError"},
+ )
+ return 0, 0, client_error or "Failed to initialize Google Drive client", 0
connector_enable_summary = getattr(connector, "enable_summary", True)
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
@@ -963,10 +1149,6 @@ async def index_google_drive_files(
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
- drive_client = GoogleDriveClient(
- session, connector_id, credentials=pre_built_credentials
- )
-
if not folder_id:
error_msg = "folder_id is required for Google Drive indexing"
await task_logger.log_task_failure(
@@ -979,8 +1161,14 @@ async def index_google_drive_files(
folder_tokens = connector.config.get("folder_tokens", {})
start_page_token = folder_tokens.get(target_folder_id)
+ is_composio_connector = (
+ connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
+ )
can_use_delta = (
- use_delta_sync and start_page_token and connector.last_indexed_at
+ not is_composio_connector
+ and use_delta_sync
+ and start_page_token
+ and connector.last_indexed_at
)
documents_unsupported = 0
@@ -1051,7 +1239,16 @@ async def index_google_drive_files(
)
if documents_indexed > 0 or can_use_delta:
- new_token, token_error = await get_start_page_token(drive_client)
+ if isinstance(drive_client, ComposioDriveClient):
+ (
+ new_token,
+ token_error,
+ ) = await drive_client.composio.get_drive_start_page_token(
+ drive_client.connected_account_id,
+ drive_client.entity_id,
+ )
+ else:
+ new_token, token_error = await get_start_page_token(drive_client)
if new_token and not token_error:
await session.refresh(connector)
if "folder_tokens" not in connector.config:
@@ -1137,32 +1334,17 @@ async def index_google_drive_single_file(
)
return 0, error_msg
- pre_built_credentials = None
- if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
- connected_account_id = connector.config.get("composio_connected_account_id")
- if not connected_account_id:
- error_msg = f"Composio connected_account_id not found for connector {connector_id}"
- await task_logger.log_task_failure(
- log_entry,
- error_msg,
- "Missing Composio account",
- {"error_type": "MissingComposioAccount"},
- )
- return 0, error_msg
- pre_built_credentials = build_composio_credentials(connected_account_id)
- else:
- token_encrypted = connector.config.get("_token_encrypted", False)
- if token_encrypted and not config.SECRET_KEY:
- await task_logger.log_task_failure(
- log_entry,
- "SECRET_KEY not configured but credentials are encrypted",
- "Missing SECRET_KEY",
- {"error_type": "MissingSecretKey"},
- )
- return (
- 0,
- "SECRET_KEY not configured but credentials are marked as encrypted",
- )
+    drive_client, client_error = _build_drive_client_for_connector(
+        session, connector_id, connector, user_id
+    )
+    if client_error or not drive_client:
+        error_msg = client_error or "Failed to initialize Google Drive client"
+        await task_logger.log_task_failure(
+            log_entry,
+            error_msg,
+            "Missing connector credentials",
+            {"error_type": "ClientInitializationError"},
+        )
+        return 0, error_msg
connector_enable_summary = getattr(connector, "enable_summary", True)
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
@@ -1171,10 +1353,6 @@ async def index_google_drive_single_file(
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
- drive_client = GoogleDriveClient(
- session, connector_id, credentials=pre_built_credentials
- )
-
file, error = await get_file_by_id(drive_client, file_id)
if error or not file:
error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"
@@ -1276,32 +1454,18 @@ async def index_google_drive_selected_files(
)
return 0, 0, [error_msg]
- pre_built_credentials = None
- if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
- connected_account_id = connector.config.get("composio_connected_account_id")
- if not connected_account_id:
- error_msg = f"Composio connected_account_id not found for connector {connector_id}"
- await task_logger.log_task_failure(
- log_entry,
- error_msg,
- "Missing Composio account",
- {"error_type": "MissingComposioAccount"},
- )
- return 0, 0, [error_msg]
- pre_built_credentials = build_composio_credentials(connected_account_id)
- else:
- token_encrypted = connector.config.get("_token_encrypted", False)
- if token_encrypted and not config.SECRET_KEY:
- error_msg = (
- "SECRET_KEY not configured but credentials are marked as encrypted"
- )
- await task_logger.log_task_failure(
- log_entry,
- error_msg,
- "Missing SECRET_KEY",
- {"error_type": "MissingSecretKey"},
- )
- return 0, 0, [error_msg]
+ drive_client, client_error = _build_drive_client_for_connector(
+ session, connector_id, connector, user_id
+ )
+ if client_error or not drive_client:
+ error_msg = client_error or "Failed to initialize Google Drive client"
+ await task_logger.log_task_failure(
+ log_entry,
+ error_msg,
+ "Missing connector credentials",
+ {"error_type": "ClientInitializationError"},
+ )
+ return 0, 0, [error_msg]
connector_enable_summary = getattr(connector, "enable_summary", True)
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
@@ -1310,10 +1474,6 @@ async def index_google_drive_selected_files(
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
- drive_client = GoogleDriveClient(
- session, connector_id, credentials=pre_built_credentials
- )
-
indexed, skipped, unsupported, errors = await _index_selected_files(
drive_client,
session,
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
index ef226087b..6697c0eb1 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
@@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
+from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
-from app.utils.google_credentials import (
- COMPOSIO_GOOGLE_CONNECTOR_TYPES,
- build_composio_credentials,
-)
+from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
from .base import (
calculate_date_range,
@@ -44,6 +42,62 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30
+def _normalize_composio_gmail_message(message: dict) -> dict:
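+    # Composio may return Gmail messages either in raw Gmail API shape
+    # (with a ``payload``) or as a flat dict. Illustrative flat example:
+    #   {"message_id": "m1", "subject": "Hi", "sender": "a@b.c", "body": "hey"}
+    # normalizes to (abridged):
+    #   {"id": "m1",
+    #    "payload": {"headers": [{"name": "Subject", "value": "Hi"},
+    #                            {"name": "From", "value": "a@b.c"}]},
+    #    "messageText": "hey", "snippet": "", ...}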
+ if message.get("payload"):
+ return message
+
+ headers = []
+ header_values = {
+ "Subject": message.get("subject"),
+ "From": message.get("from") or message.get("sender"),
+ "To": message.get("to") or message.get("recipient"),
+ "Date": message.get("date"),
+ }
+ for name, value in header_values.items():
+ if value:
+ headers.append({"name": name, "value": value})
+
+ return {
+ **message,
+ "id": message.get("id")
+ or message.get("message_id")
+ or message.get("messageId"),
+ "threadId": message.get("threadId") or message.get("thread_id"),
+ "payload": {"headers": headers},
+ "snippet": message.get("snippet", ""),
+ "messageText": message.get("messageText") or message.get("body") or "",
+ }
+
+
+def _format_gmail_message_to_markdown(message: dict) -> str:
+ headers = {
+ header.get("name", "").lower(): header.get("value", "")
+ for header in message.get("payload", {}).get("headers", [])
+ if isinstance(header, dict)
+ }
+ subject = headers.get("subject", "No Subject")
+ from_email = headers.get("from", "Unknown Sender")
+ to_email = headers.get("to", "Unknown Recipient")
+ date_str = headers.get("date", "Unknown Date")
+ message_text = (
+ message.get("messageText")
+ or message.get("body")
+ or message.get("text")
+ or message.get("snippet", "")
+ )
+
+ return (
+ f"# {subject}\n\n"
+ f"**From:** {from_email}\n"
+ f"**To:** {to_email}\n"
+ f"**Date:** {date_str}\n\n"
+ f"## Message Content\n\n{message_text}\n\n"
+ f"## Message Details\n\n"
+ f"- **Message ID:** {message.get('id', 'Unknown')}\n"
+ f"- **Thread ID:** {message.get('threadId', 'Unknown')}\n"
+ )
+
+
def _build_connector_doc(
message: dict,
markdown_content: str,
@@ -162,7 +216,14 @@ async def index_google_gmail_messages(
)
return 0, 0, error_msg
- # ── Credential building ───────────────────────────────────────
+ is_composio_connector = (
+ connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
+ )
+ gmail_connector = None
+ composio_service = None
+ connected_account_id = None
+
+ # ── Credential/client building ────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
@@ -173,7 +234,7 @@ async def index_google_gmail_messages(
{"error_type": "MissingComposioAccount"},
)
return 0, 0, "Composio connected_account_id not found"
- credentials = build_composio_credentials(connected_account_id)
+ composio_service = ComposioService()
else:
config_data = connector.config
@@ -241,9 +302,10 @@ async def index_google_gmail_messages(
{"stage": "client_initialization"},
)
- gmail_connector = GoogleGmailConnector(
- credentials, session, user_id, connector_id
- )
+ if not is_composio_connector:
+ gmail_connector = GoogleGmailConnector(
+ credentials, session, user_id, connector_id
+ )
calculated_start_date, calculated_end_date = calculate_date_range(
connector, start_date, end_date, default_days_back=365
@@ -254,11 +316,60 @@ async def index_google_gmail_messages(
f"Fetching emails for connector {connector_id} "
f"from {calculated_start_date} to {calculated_end_date}"
)
- messages, error = await gmail_connector.get_recent_messages(
- max_results=max_messages,
- start_date=calculated_start_date,
- end_date=calculated_end_date,
- )
+ if is_composio_connector:
+ query_parts = []
+ if calculated_start_date:
+ query_parts.append(f"after:{calculated_start_date.replace('-', '/')}")
+ if calculated_end_date:
+ query_parts.append(f"before:{calculated_end_date.replace('-', '/')}")
+ query = " ".join(query_parts)
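+        # e.g. 2024-01-01 .. 2024-06-30 becomes the Gmail search query
+        # "after:2024/01/01 before:2024/06/30" (Gmail's date operators
+        # take slash-separated dates, hence the replace above).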
+
+ messages = []
+ page_token = None
+ error = None
+ while len(messages) < max_messages:
+ page_size = min(50, max_messages - len(messages))
+ (
+ page_messages,
+ page_token,
+ _estimate,
+ page_error,
+ ) = await composio_service.get_gmail_messages(
+ connected_account_id=connected_account_id,
+ entity_id=f"surfsense_{user_id}",
+ query=query,
+ max_results=page_size,
+ page_token=page_token,
+ )
+ if page_error:
+ error = page_error
+ break
+ for page_message in page_messages:
+ message_id = (
+ page_message.get("id")
+ or page_message.get("message_id")
+ or page_message.get("messageId")
+ )
+ if message_id:
+ (
+ detail,
+ detail_error,
+ ) = await composio_service.get_gmail_message_detail(
+ connected_account_id=connected_account_id,
+ entity_id=f"surfsense_{user_id}",
+ message_id=message_id,
+ )
+ if not detail_error and isinstance(detail, dict):
+ page_message = detail
+ messages.append(_normalize_composio_gmail_message(page_message))
+ if not page_token:
+ break
+ else:
+ messages, error = await gmail_connector.get_recent_messages(
+ max_results=max_messages,
+ start_date=calculated_start_date,
+ end_date=calculated_end_date,
+ )
if error:
error_message = error
@@ -326,7 +437,12 @@ async def index_google_gmail_messages(
documents_skipped += 1
continue
- markdown_content = gmail_connector.format_message_to_markdown(message)
+ if is_composio_connector:
+ markdown_content = _format_gmail_message_to_markdown(message)
+ else:
+ markdown_content = gmail_connector.format_message_to_markdown(
+ message
+ )
if not markdown_content.strip():
logger.warning(f"Skipping message with no content: {message_id}")
documents_skipped += 1
diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml
index 131627386..da8c4b7d1 100644
--- a/surfsense_backend/pyproject.toml
+++ b/surfsense_backend/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "surf-new-backend"
-version = "0.0.19"
+version = "0.0.21"
description = "SurfSense Backend"
requires-python = ">=3.12"
dependencies = [
@@ -71,11 +71,11 @@ dependencies = [
"langchain>=1.2.13",
"langgraph>=1.1.3",
"langchain-community>=0.4.1",
- "deepagents>=0.4.12",
"stripe>=15.0.0",
"azure-ai-documentintelligence>=1.0.2",
- "litellm>=1.83.4",
+ "litellm>=1.83.7",
"langchain-litellm>=0.6.4",
+ "deepagents>=0.4.12,<0.5",
]
[dependency-groups]
diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py
new file mode 100644
index 000000000..a49d4eab2
--- /dev/null
+++ b/surfsense_backend/scripts/verify_chat_image_capability.py
@@ -0,0 +1,558 @@
+"""End-to-end smoke test for vision / image config wiring.
+
+Loads the live ``global_llm_config.yaml`` (no mocking, no fixtures) and
+exercises every chat / vision / image-generation config + the OpenRouter
+dynamic catalog. For each config the script:
+
+1. Reports the resolver classification (catalog-allow vs strict-block).
+2. Optionally fires a tiny live API call against the provider:
+ - Chat configs: ``litellm.acompletion`` with a 1x1 PNG and the prompt
+ ``"reply with one word: ok"``.
+ - Vision configs: same, against the dedicated vision router pool.
+ - Image-gen configs: ``litellm.aimage_generation`` with a single tiny
+ prompt and ``n=1``.
+ - OpenRouter integration: samples one chat, one vision, one image-gen
+ model from the dynamically fetched catalog.
+
+Usage::
+
+ python -m scripts.verify_chat_image_capability # capability + connectivity
+ python -m scripts.verify_chat_image_capability --no-live # capability resolver only
+
+The script is meant to be runnable from the repository root or from
+``surfsense_backend/`` and prints a short PASS/FAIL/SKIP summary at the
+end so it's usable as a CI smoke check too.
+
+Live-mode caveat: each successful call costs a small amount of provider
+credit (a few tokens or one tiny generated image per config). The
+default size for image generation is ``1024x1024`` because Azure
+GPT-image deployments reject smaller sizes; OpenRouter image-gen models
+generally accept the same size.
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import logging
+import os
+import sys
+import time
+from dataclasses import dataclass, field
+from typing import Any
+
+# Bootstrap the surfsense_backend package on sys.path so the script runs
+# from the repo root or from `surfsense_backend/` interchangeably.
+_HERE = os.path.dirname(os.path.abspath(__file__))
+_BACKEND_ROOT = os.path.dirname(_HERE)
+if _BACKEND_ROOT not in sys.path:
+ sys.path.insert(0, _BACKEND_ROOT)
+
+import litellm # noqa: E402
+
+from app.config import config # noqa: E402
+from app.services.openrouter_integration_service import ( # noqa: E402
+ _OPENROUTER_DYNAMIC_MARKER,
+ OpenRouterIntegrationService,
+)
+from app.services.provider_api_base import resolve_api_base # noqa: E402
+from app.services.provider_capabilities import ( # noqa: E402
+ derive_supports_image_input,
+ is_known_text_only_chat_model,
+)
+
+logging.basicConfig(
+ level=logging.WARNING,
+ format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
+)
+# Quiet down LiteLLM's verbose router/cost logs so the script output is
+# scannable.
+logging.getLogger("LiteLLM").setLevel(logging.ERROR)
+logging.getLogger("litellm").setLevel(logging.ERROR)
+logging.getLogger("httpx").setLevel(logging.ERROR)
+
+# 1x1 transparent PNG — used as the cheapest possible vision payload.
+_TINY_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
+_TINY_PNG_DATA_URL = f"data:image/png;base64,{_TINY_PNG_B64}"
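+# Regeneration sketch (hedged; assumes Pillow is installed):
+#   import base64, io
+#   from PIL import Image
+#   buf = io.BytesIO()
+#   Image.new("RGBA", (1, 1)).save(buf, format="PNG")
+#   print(base64.b64encode(buf.getvalue()).decode())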
+
+
+# ---------------------------------------------------------------------------
+# Result accounting
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class ProbeResult:
+ label: str
+ surface: str
+ config_id: int | str
+ capability_ok: bool | None = None
+ capability_note: str = ""
+ live_ok: bool | None = None
+ live_note: str = ""
+ duration_s: float = 0.0
+
+
+@dataclass
+class Report:
+ results: list[ProbeResult] = field(default_factory=list)
+
+ def add(self, r: ProbeResult) -> None:
+ self.results.append(r)
+
+ def render(self) -> int:
+ passed = failed = skipped = 0
+ print()
+ print("=" * 92)
+ print(
+ f"{'Surface':<14}{'ID':>8} {'Cap':>5} {'Live':>5} {'Time':>6} Label / notes"
+ )
+ print("-" * 92)
+        def _flag(value: bool | None) -> str:
+            if value is None:
+                return "skip"
+            return "ok" if value else "fail"
+
+        for r in self.results:
+            cap = _flag(r.capability_ok)
+            live = _flag(r.live_ok)
+ if r.capability_ok is False or r.live_ok is False:
+ failed += 1
+ elif r.capability_ok is None and r.live_ok is None:
+ skipped += 1
+ else:
+ passed += 1
+ print(
+ f"{r.surface:<14}{r.config_id!s:>8} {cap:>5} {live:>5} "
+ f"{r.duration_s:>5.2f}s {r.label}"
+ )
+ if r.capability_note:
+ print(f" cap: {r.capability_note}")
+ if r.live_note:
+ print(f" live: {r.live_note}")
+ print("-" * 92)
+ print(
+ f"Total: {passed} ok / {failed} fail / {skipped} skip "
+ f"(of {len(self.results)} probes)"
+ )
+ print("=" * 92)
+ return failed
+
+
+# ---------------------------------------------------------------------------
+# Capability probes (no network)
+# ---------------------------------------------------------------------------
+
+
+def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
+    """For chat configs the catalog flag is *expected* to be True
+    (vision-capable pool). The probe reports both the resolver value and
+    the strict safety-net value to surface any drift between them."""
+ litellm_params = cfg.get("litellm_params") or {}
+ base_model = (
+ litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
+ )
+ cap = derive_supports_image_input(
+ provider=cfg.get("provider"),
+ model_name=cfg.get("model_name"),
+ base_model=base_model,
+ custom_provider=cfg.get("custom_provider"),
+ )
+ block = is_known_text_only_chat_model(
+ provider=cfg.get("provider"),
+ model_name=cfg.get("model_name"),
+ base_model=base_model,
+ custom_provider=cfg.get("custom_provider"),
+ )
+ note = f"derive={cap} strict_block={block}"
+ if not cap and not block:
+ # Resolver said False but strict gate is also False — that means
+ # OR modalities published [text] explicitly. Surface it.
+ note += " (OR modality says text-only)"
+ # We accept a True derive *or* (False derive AND False block) as
+ # 'capability ok' — either way, the streaming task will flow through.
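+    # Truth table: derive=True               -> ok   (catalog-allowed)
+    #              derive=False, block=False -> ok   (no strict block; noted)
+    #              derive=False, block=True  -> fail (known text-only model)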
+ ok = cap or not block
+ return ok, note
+
+
+def _build_chat_model_string(cfg: dict) -> str:
+ if cfg.get("custom_provider"):
+ return f"{cfg['custom_provider']}/{cfg['model_name']}"
+ from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
+
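+    # e.g. provider "OPENAI" + model_name "gpt-4o" -> "openai/gpt-4o";
+    # falls back to the lower-cased provider name when the map has no entry.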
+ prefix = _PROVIDER_PREFIX_MAP.get(
+ (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
+ )
+ return f"{prefix}/{cfg['model_name']}"
+
+
+# ---------------------------------------------------------------------------
+# Live probes (network calls)
+# ---------------------------------------------------------------------------
+
+
+async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
+ """Send a 1x1 PNG + `reply with one word: ok` to the chat config."""
+ model_string = _build_chat_model_string(cfg)
+ api_base = resolve_api_base(
+ provider=cfg.get("provider"),
+ provider_prefix=model_string.split("/", 1)[0],
+ config_api_base=cfg.get("api_base") or None,
+ )
+ kwargs: dict[str, Any] = {
+ "model": model_string,
+ "api_key": cfg.get("api_key"),
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "reply with one word: ok"},
+ {
+ "type": "image_url",
+ "image_url": {"url": _TINY_PNG_DATA_URL},
+ },
+ ],
+ }
+ ],
+ "max_tokens": 16,
+ "timeout": 60,
+ }
+ if api_base:
+ kwargs["api_base"] = api_base
+ if cfg.get("litellm_params"):
+ # Strip pricing keys — they're tracking-only and confuse some
+ # provider validators (e.g. azure/openai reject unknown kwargs
+ # in strict mode).
+ merged = {
+ k: v
+ for k, v in dict(cfg["litellm_params"]).items()
+ if k
+ not in {
+ "input_cost_per_token",
+ "output_cost_per_token",
+ "input_cost_per_pixel",
+ "output_cost_per_pixel",
+ }
+ }
+ kwargs.update(merged)
+ try:
+ resp = await litellm.acompletion(**kwargs)
+ except Exception as exc:
+ return False, f"{type(exc).__name__}: {exc}"
+ text = resp.choices[0].message.content if resp.choices else ""
+ return True, f"got reply ({(text or '').strip()[:40]!r})"
+
+
+# Gemini image models occasionally return zero-length ``data`` for the
+# minimal "red dot on white" prompt (provider-side safety / empty-output
+# quirk reproducible against ``google/gemini-2.5-flash-image`` even when
+# the request itself succeeds). Use a more naturalistic prompt and
+# retry once with a different one before giving up.
+_IMAGE_GEN_PROMPTS: tuple[str, ...] = (
+ "A simple icon of a coffee cup, flat illustration",
+ "A small green leaf on a white background",
+)
+
+
+async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
+ """Generate one tiny image to verify the deployment is reachable."""
+ from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
+
+ if cfg.get("custom_provider"):
+ prefix = cfg["custom_provider"]
+ else:
+ prefix = _PROVIDER_PREFIX_MAP.get(
+ (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
+ )
+ model_string = f"{prefix}/{cfg['model_name']}"
+ api_base = resolve_api_base(
+ provider=cfg.get("provider"),
+ provider_prefix=prefix,
+ config_api_base=cfg.get("api_base") or None,
+ )
+ base_kwargs: dict[str, Any] = {
+ "model": model_string,
+ "api_key": cfg.get("api_key"),
+ "n": 1,
+ "size": "1024x1024",
+ "timeout": 120,
+ }
+ if api_base:
+ base_kwargs["api_base"] = api_base
+ if cfg.get("api_version"):
+ base_kwargs["api_version"] = cfg["api_version"]
+ if cfg.get("litellm_params"):
+ base_kwargs.update(
+ {
+ k: v
+ for k, v in dict(cfg["litellm_params"]).items()
+ if k
+ not in {
+ "input_cost_per_token",
+ "output_cost_per_token",
+ "input_cost_per_pixel",
+ "output_cost_per_pixel",
+ }
+ }
+ )
+
+ last_note = ""
+ for attempt, prompt in enumerate(_IMAGE_GEN_PROMPTS, start=1):
+ try:
+ resp = await litellm.aimage_generation(prompt=prompt, **base_kwargs)
+ except Exception as exc:
+ last_note = f"{type(exc).__name__}: {exc}"
+ continue
+ data_count = len(getattr(resp, "data", None) or [])
+ if data_count > 0:
+ return True, (
+ f"received {data_count} image(s) on attempt {attempt} "
+ f"(prompt={prompt!r})"
+ )
+ last_note = (
+ f"call ok but received 0 images on attempt {attempt} (prompt={prompt!r})"
+ )
+ return False, last_note
+
+
+# ---------------------------------------------------------------------------
+# Probe drivers
+# ---------------------------------------------------------------------------
+
+
+def _is_or_dynamic(cfg: dict) -> bool:
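+    # Configs injected from the live OpenRouter catalog carry the dynamic
+    # marker key; YAML-static entries don't, so this cleanly splits the two.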
+ return bool(cfg.get(_OPENROUTER_DYNAMIC_MARKER))
+
+
+async def probe_chat_configs(report: Report, *, live: bool) -> None:
+ print("\n[chat configs from global_llm_configs (YAML-static)]")
+ for cfg in config.GLOBAL_LLM_CONFIGS:
+ # Skip OR dynamic entries here — handled in the OR section so
+ # the YAML / OR split stays clear in the report.
+ if _is_or_dynamic(cfg):
+ continue
+ result = ProbeResult(
+ label=str(cfg.get("name") or cfg.get("model_name")),
+ surface="chat-yaml",
+ config_id=cfg.get("id"),
+ )
+ cap_ok, cap_note = _probe_chat_capability(cfg)
+ result.capability_ok = cap_ok
+ result.capability_note = cap_note
+ if live:
+ t0 = time.perf_counter()
+ ok, note = await _live_chat_image_call(cfg)
+ result.live_ok = ok
+ result.live_note = note
+ result.duration_s = time.perf_counter() - t0
+ report.add(result)
+
+
+async def probe_vision_configs(report: Report, *, live: bool) -> None:
+ print("\n[vision configs from global_vision_llm_configs (YAML-static)]")
+ for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
+ if _is_or_dynamic(cfg):
+ continue
+ result = ProbeResult(
+ label=str(cfg.get("name") or cfg.get("model_name")),
+ surface="vision",
+ config_id=cfg.get("id"),
+ )
+ # For vision configs, capability is implied — they're in the
+ # dedicated vision pool. Run the same resolver to flag any
+ # surprise disagreement.
+ cap_ok, cap_note = _probe_chat_capability(cfg)
+ result.capability_ok = cap_ok
+ result.capability_note = cap_note
+ if live:
+ t0 = time.perf_counter()
+ ok, note = await _live_chat_image_call(cfg)
+ result.live_ok = ok
+ result.live_note = note
+ result.duration_s = time.perf_counter() - t0
+ report.add(result)
+
+
+async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
+ print(
+ "\n[image generation configs from global_image_generation_configs (YAML-static)]"
+ )
+ for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
+ if _is_or_dynamic(cfg):
+ continue
+ result = ProbeResult(
+ label=str(cfg.get("name") or cfg.get("model_name")),
+ surface="image-gen",
+ config_id=cfg.get("id"),
+ )
+ # Image gen configs don't have a "supports_image_input" flag;
+ # the catalog tracks output, not input. Mark capability as None
+ # (skip) for the report.
+ if live:
+ t0 = time.perf_counter()
+ ok, note = await _live_image_gen_call(cfg)
+ result.live_ok = ok
+ result.live_note = note
+ result.duration_s = time.perf_counter() - t0
+ report.add(result)
+
+
+async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
+ """Sample one chat (vision-capable), one vision, one image-gen model
+ from the live OpenRouter catalogue. Doesn't iterate the full pool
+ (would be hundreds of probes); just validates the integration end-
+ to-end on a representative model from each surface."""
+ print("\n[OpenRouter integration: sampled probes]")
+ settings = config.OPENROUTER_INTEGRATION_SETTINGS
+ if not settings:
+ report.add(
+ ProbeResult(
+ label="OpenRouter integration",
+ surface="openrouter",
+ config_id="settings",
+ capability_ok=None,
+ capability_note="openrouter_integration disabled in YAML — skipping",
+ live_ok=None,
+ )
+ )
+ return
+
+ service = OpenRouterIntegrationService.get_instance()
+ or_chat = [
+ c
+ for c in config.GLOBAL_LLM_CONFIGS
+ if c.get("provider") == "OPENROUTER" and c.get("supports_image_input")
+ ]
+ or_vision = [
+ c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER"
+ ]
+ or_image_gen = [
+ c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER"
+ ]
+
+ # Pick one representative per provider family per surface so a single
+ # broken vendor (e.g. Anthropic key revoked, Google quota exceeded)
+ # surfaces independently of the others. Each needle matches the
+ # OpenRouter ``model_name`` prefix; the first match wins.
+ def _pick_first(pool: list[dict], needle: str) -> dict | None:
+ for c in pool:
+ if (c.get("model_name") or "").lower().startswith(needle):
+ return c
+ return None
+
+ chat_picks = [
+ ("or-chat", _pick_first(or_chat, "openai/gpt-4o")),
+ ("or-chat", _pick_first(or_chat, "anthropic/claude")),
+ ("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")),
+ ]
+ vision_picks = [
+ ("or-vision", _pick_first(or_vision, "openai/gpt-4o")),
+ ("or-vision", _pick_first(or_vision, "anthropic/claude")),
+ ("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")),
+ ]
+ image_picks = [
+ ("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")),
+ # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*``
+ # / ``openai/gpt-5.4-image-2`` (no ``gpt-image`` literal). Match
+ # the actual prefix.
+ ("or-image", _pick_first(or_image_gen, "openai/gpt-5-image")),
+ ]
+
+ print(
+ f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} "
+ f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})"
+ )
+
+ for surface, picked in chat_picks + vision_picks + image_picks:
+ if not picked:
+ report.add(
+ ProbeResult(
+                    label="(no candidate)",
+ surface=surface,
+ config_id="-",
+ capability_ok=None,
+ capability_note="no candidate found in OR catalog",
+ )
+ )
+ continue
+ runner = (
+ _live_image_gen_call if surface == "or-image" else _live_chat_image_call
+ )
+ result = ProbeResult(
+ label=str(picked.get("model_name")),
+ surface=surface,
+ config_id=picked.get("id"),
+ )
+ if surface != "or-image":
+ cap_ok, cap_note = _probe_chat_capability(picked)
+ result.capability_ok = cap_ok
+ result.capability_note = cap_note
+ if live:
+ t0 = time.perf_counter()
+ ok, note = await runner(picked)
+ result.live_ok = ok
+ result.live_note = note
+ result.duration_s = time.perf_counter() - t0
+ report.add(result)
+
+
+# ---------------------------------------------------------------------------
+# Entry point
+# ---------------------------------------------------------------------------
+
+
+async def main(args: argparse.Namespace) -> int:
+ print("Loaded global configs:")
+ print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries")
+ print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries")
+ print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries")
+ print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}")
+
+ # Initialize the OpenRouter integration so the catalog is populated
+ # (this is what main.py does at startup). It's idempotent.
+ if config.OPENROUTER_INTEGRATION_SETTINGS:
+ try:
+ from app.config import initialize_openrouter_integration
+
+ initialize_openrouter_integration()
+ except Exception as exc:
+ print(f" WARNING: OpenRouter integration init failed: {exc}")
+
+ print(
+ f"\nMode: {'LIVE (will hit providers)' if args.live else 'DRY (capability only)'}"
+ )
+
+ report = Report()
+ if not args.skip_chat:
+ await probe_chat_configs(report, live=args.live)
+ if not args.skip_vision:
+ await probe_vision_configs(report, live=args.live)
+ if not args.skip_image_gen:
+ await probe_image_gen_configs(report, live=args.live)
+ if not args.skip_openrouter:
+ await probe_openrouter_catalog(report, live=args.live)
+
+ failed = report.render()
+ return 1 if failed else 0
+
+
+def _parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--no-live",
+ dest="live",
+ action="store_false",
+ help="Skip live API calls — capability resolver only.",
+ )
+ parser.set_defaults(live=True)
+ parser.add_argument("--skip-chat", action="store_true")
+ parser.add_argument("--skip-vision", action="store_true")
+ parser.add_argument("--skip-image-gen", action="store_true")
+ parser.add_argument("--skip-openrouter", action="store_true")
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = _parse_args()
+ sys.exit(asyncio.run(main(args)))
diff --git a/surfsense_backend/tests/integration/chat/__init__.py b/surfsense_backend/tests/integration/chat/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/surfsense_backend/tests/integration/chat/test_append_message_recovery.py b/surfsense_backend/tests/integration/chat/test_append_message_recovery.py
new file mode 100644
index 000000000..a5182a978
--- /dev/null
+++ b/surfsense_backend/tests/integration/chat/test_append_message_recovery.py
@@ -0,0 +1,573 @@
+"""Integration tests for the cross-writer contract between the
+streaming chat task and the legacy ``POST /threads/{id}/messages``
+(``append_message``) round-trip.
+
+Two scenarios anchor the contract introduced by the server-side
+persistence rework:
+
+(a) **Tool-heavy turn streamed to completion.**
+
+ Drives :class:`AssistantContentBuilder` with synthetic SSE events
+ that mirror what ``_stream_agent_events`` emits for a turn that
+ interleaves text, reasoning, a tool call (start/delta/available/
+ output), and a final text block. Then runs
+ :func:`finalize_assistant_turn` and asserts:
+
+ * ``new_chat_messages.content`` JSONB matches the
+ ``ContentPart[]`` shape the FE history loader expects, with full
+ ``args``/``argsText``/``result``/``langchainToolCallId`` for the
+ tool call.
+ * Exactly one ``token_usage`` row exists keyed on the assistant
+ ``message_id``.
+
+(b) **Stale FE ``appendMessage`` after server finalize.**
+
+ Verifies the recovery branch of the ``append_message`` route now
+ returns the SERVER's authoritative ``ContentPart[]`` (not the FE's
+ stale payload) when the partial unique index from migration 141
+ blocks the FE's INSERT, and that the ``ON CONFLICT DO NOTHING``
+ clause from migration 142 stops the route from writing a duplicate
+ ``token_usage`` row.
+"""
+
+from __future__ import annotations
+
+import json
+from contextlib import asynccontextmanager
+
+import pytest
+import pytest_asyncio
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import (
+ ChatVisibility,
+ NewChatMessage,
+ NewChatMessageRole,
+ NewChatThread,
+ SearchSpace,
+ TokenUsage,
+ User,
+)
+from app.routes import new_chat_routes
+from app.services.token_tracking_service import TurnTokenAccumulator
+from app.tasks.chat import persistence as persistence_module
+from app.tasks.chat.content_builder import AssistantContentBuilder
+from app.tasks.chat.persistence import (
+ finalize_assistant_turn,
+ persist_assistant_shell,
+)
+
+pytestmark = pytest.mark.integration
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest_asyncio.fixture
+async def db_thread(
+ db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
+) -> NewChatThread:
+ thread = NewChatThread(
+ title="Test Chat",
+ search_space_id=db_search_space.id,
+ created_by_id=db_user.id,
+ visibility=ChatVisibility.PRIVATE,
+ )
+ db_session.add(thread)
+ await db_session.flush()
+ return thread
+
+
+@pytest.fixture
+def patched_shielded_session(monkeypatch, db_session: AsyncSession):
+ """Route persistence helpers to the test's savepoint-bound session.
+
+ Mirrors the helper from ``test_persistence.py`` so the helpers'
+ internal ``ws.commit()`` / ``ws.rollback()`` resolve to SAVEPOINT
+ operations on the test transaction instead of touching real
+ autocommit boundaries.
+ """
+
+ @asynccontextmanager
+ async def _fake_shielded_session():
+ yield db_session
+
+ monkeypatch.setattr(
+ persistence_module,
+ "shielded_async_session",
+ _fake_shielded_session,
+ )
+ return db_session
+
+
+@pytest.fixture
+def bypass_permission_checks(monkeypatch):
+ """Replace RBAC + thread access checks with no-ops.
+
+ The append_message route under test calls ``check_permission`` and
+ ``check_thread_access``; those rely on a SearchSpaceMembership row
+ that the existing integration fixtures don't create. The contract
+ we want to verify here is the ``IntegrityError`` -> recovery branch,
+ not the RBAC plumbing — so stub them.
+ """
+
+ async def _allow(*_args, **_kwargs):
+ return True
+
+ monkeypatch.setattr(new_chat_routes, "check_permission", _allow)
+ monkeypatch.setattr(new_chat_routes, "check_thread_access", _allow)
+ return None
+
+
+class _FakeRequest:
+ """Minimal Request stand-in used by ``append_message``.
+
+ The route only calls ``await request.json()`` — keep the surface
+ area tight so this doesn't accidentally hide future signature
+ changes that we *would* want to break the test.
+ """
+
+ def __init__(self, body: dict):
+ self._body = body
+
+ async def json(self) -> dict:
+ return self._body
+
+
+def _build_tool_heavy_content() -> list[dict]:
+ """Drive ``AssistantContentBuilder`` through a tool-heavy turn.
+
+ Produces the same ``ContentPart[]`` shape the streaming layer would
+ persist if ``_stream_agent_events`` ran a turn with: opening
+ reasoning -> text -> tool call (input start/delta/available/output)
+ -> closing text. Centralised here so the (a) and (b) scenarios use
+ the same authoritative payload.
+ """
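+    # Resulting snapshot shape (abridged): [reasoning, text, tool-call
+    # {args, argsText, result, langchainToolCallId}, text]; asserted in
+    # detail by the tool-heavy finalize test below.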
+ builder = AssistantContentBuilder()
+
+ builder.on_reasoning_start("r1")
+ builder.on_reasoning_delta("r1", "Let me look up ")
+ builder.on_reasoning_delta("r1", "the file listing.")
+ builder.on_reasoning_end("r1")
+
+ builder.on_text_start("t1")
+ builder.on_text_delta("t1", "Sure, listing files in ")
+ builder.on_text_delta("t1", "/.")
+ builder.on_text_end("t1")
+
+ builder.on_tool_input_start(
+ "tool_call_ui_1",
+ tool_name="ls",
+ langchain_tool_call_id="lc_call_xyz",
+ )
+ builder.on_tool_input_delta("tool_call_ui_1", '{"path"')
+ builder.on_tool_input_delta("tool_call_ui_1", ': "/"}')
+ builder.on_tool_input_available(
+ "tool_call_ui_1",
+ tool_name="ls",
+ args={"path": "/"},
+ langchain_tool_call_id="lc_call_xyz",
+ )
+ builder.on_tool_output_available(
+ "tool_call_ui_1",
+ output={"files": ["a.txt", "b.txt"]},
+ langchain_tool_call_id="lc_call_xyz",
+ )
+
+ builder.on_text_start("t2")
+ builder.on_text_delta("t2", "Found 2 files: a.txt and b.txt.")
+ builder.on_text_end("t2")
+
+ return builder.snapshot()
+
+
+def _accumulator_with_one_call() -> TurnTokenAccumulator:
+ acc = TurnTokenAccumulator()
+ acc.add(
+ model="gpt-4o-mini",
+ prompt_tokens=200,
+ completion_tokens=80,
+ total_tokens=280,
+ cost_micros=22222,
+ )
+ return acc
+
+
+# ---------------------------------------------------------------------------
+# (a) Tool-heavy stream finalize
+# ---------------------------------------------------------------------------
+
+
+class TestToolHeavyTurnFinalize:
+ async def test_full_tool_call_persisted_and_one_token_usage_row(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ db_search_space,
+ patched_shielded_session,
+ ):
+ """End-to-end seam: builder snapshot -> finalize -> DB row.
+
+ Matches the production flow's *content* invariant: whatever
+ ``AssistantContentBuilder.snapshot()`` produces is what the
+ streaming layer hands to ``finalize_assistant_turn``, so this
+ test catches any drift between the JSONB shape the builder
+ emits and the one the FE history loader expects.
+ """
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ search_space_id = db_search_space.id
+ turn_id = f"{thread_id}:tool_heavy"
+
+ msg_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ assert msg_id is not None
+
+ snapshot = _build_tool_heavy_content()
+ # Sanity-check the snapshot before we hand it to the DB so a
+ # builder regression surfaces here, not deep inside an opaque
+ # JSONB diff.
+ assert any(p.get("type") == "reasoning" for p in snapshot)
+ text_parts = [p for p in snapshot if p.get("type") == "text"]
+ assert len(text_parts) == 2
+ tool_parts = [p for p in snapshot if p.get("type") == "tool-call"]
+ assert len(tool_parts) == 1
+ tool_part = tool_parts[0]
+ assert tool_part["toolCallId"] == "tool_call_ui_1"
+ assert tool_part["toolName"] == "ls"
+ assert tool_part["args"] == {"path": "/"}
+ # ``argsText`` ends up as the pretty-printed final args (the
+ # ``tool-input-available`` event replaces the streamed deltas
+ # with ``json.dumps(args, indent=2)`` to match the FE's
+ # post-stream rendering).
+ assert tool_part["argsText"] == '{\n "path": "/"\n}'
+ assert tool_part["result"] == {"files": ["a.txt", "b.txt"]}
+ # ``langchainToolCallId`` is the agent-side correlation id used
+ # by the regenerate path; a missing one breaks
+ # edit-from-tool-call later.
+ assert tool_part["langchainToolCallId"] == "lc_call_xyz"
+
+ await finalize_assistant_turn(
+ message_id=msg_id,
+ chat_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ content=snapshot,
+ accumulator=_accumulator_with_one_call(),
+ )
+
+ # ``content`` must round-trip byte-for-byte through the JSONB
+ # column. SQLAlchemy doesn't auto-refresh the row that survived
+ # the savepoint commit, so refresh explicitly.
+ row = await db_session.get(NewChatMessage, msg_id)
+ await db_session.refresh(row)
+
+ # The history loader reads ``content`` straight into the FE's
+ # parts array, so a strict equality comparison is the right
+ # invariant here.
+ assert row.content == snapshot
+ # Tool-call parts must JSON-serialise cleanly — nothing in
+ # ``args`` / ``argsText`` / ``result`` should accidentally be a
+ # non-JSON-safe value (datetime, set, custom class).
+ assert json.dumps(row.content)
+
+ usage_count = (
+ await db_session.execute(
+ select(func.count())
+ .select_from(TokenUsage)
+ .where(TokenUsage.message_id == msg_id)
+ )
+ ).scalar_one()
+ assert usage_count == 1
+
+ usage = (
+ await db_session.execute(
+ select(TokenUsage).where(TokenUsage.message_id == msg_id)
+ )
+ ).scalar_one()
+ assert usage.usage_type == "chat"
+ assert usage.prompt_tokens == 200
+ assert usage.completion_tokens == 80
+ assert usage.total_tokens == 280
+ assert usage.cost_micros == 22222
+ assert usage.thread_id == thread_id
+ assert usage.search_space_id == search_space_id
+
+
+# ---------------------------------------------------------------------------
+# (b) FE appendMessage after server finalize
+# ---------------------------------------------------------------------------
+
+
+class TestAppendMessageRecoveryAfterFinalize:
+ async def test_returns_server_content_and_does_not_duplicate_token_usage(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ db_search_space,
+ patched_shielded_session,
+ bypass_permission_checks,
+ ):
+ """FE's stale ``appendMessage`` after server finalize.
+
+ The frontend used to be the authoritative writer for assistant
+ ``content``. Now the server is. When the legacy FE round-trip
+ fires *after* the server has already finalized:
+
+ * the route's INSERT trips the (thread_id, turn_id, role)
+ partial unique index from migration 141,
+ * the recovery branch fetches the existing row and returns
+ *its* ``content`` — discarding the FE payload — so the
+ history loader reads the rich server payload (full tool
+ args, argsText, langchainToolCallId, etc.) on next page
+ reload,
+ * the route's optional ``token_usage`` insert is keyed on the
+ partial unique index from migration 142 so it silently
+ no-ops if the server already wrote one.
+ """
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ search_space_id = db_search_space.id
+ turn_id = f"{thread_id}:fe_late_append"
+
+ # Step 1: server stream completes. Server-built rich content is
+ # finalized.
+ msg_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ assert msg_id is not None
+
+ server_content = _build_tool_heavy_content()
+ await finalize_assistant_turn(
+ message_id=msg_id,
+ chat_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ content=server_content,
+ accumulator=_accumulator_with_one_call(),
+ )
+
+ # Step 2: simulate the legacy FE ``appendMessage`` round-trip
+ # arriving with stale, lossy content (missing tool args, etc.)
+ # plus a ``token_usage`` body.
+ fe_stale_content = [
+ {"type": "text", "text": "Found 2 files: a.txt and b.txt."},
+ ]
+ fe_request_body = {
+ "role": "assistant",
+ "content": fe_stale_content,
+ "turn_id": turn_id,
+ "token_usage": {
+ "prompt_tokens": 999,
+ "completion_tokens": 999,
+ "total_tokens": 1998,
+ "cost_micros": 88888,
+ "usage": {"any": "thing"},
+ "call_details": {"calls": []},
+ },
+ }
+ request = _FakeRequest(fe_request_body)
+
+ # ``db_user`` is bound to ``db_session``. The route's
+ # IntegrityError branch calls ``session.rollback()``, which
+ # expires every ORM row attached to the session including this
+ # user — historically causing ``user.id`` to lazy-load
+ # out-of-greenlet and crash the request with ``MissingGreenlet``
+ # (observed in production logs at /api/v1/threads/531/messages
+ # 2026-05-04). The route now captures ``user.id`` to a primitive
+ # UUID at the top of the handler, so the rollback can't reach
+ # it. Pass the *attached* user here on purpose — that's the
+ # production scenario, and this test is the regression guard
+ # against that bug returning.
+ response = await new_chat_routes.append_message(
+ thread_id=thread_id,
+ request=request,
+ session=db_session,
+ user=db_user,
+ )
+
+ # Response must echo the SERVER's rich payload, not the FE's
+ # stale snapshot. This is the user-visible part of the
+ # contract: history reload + ThreadHistoryAdapter.append both
+ # read from the same authoritative source.
+ assert response.id == msg_id
+ assert response.role == NewChatMessageRole.ASSISTANT
+ assert response.turn_id == turn_id
+ assert response.content == server_content
+ assert response.content != fe_stale_content
+
+ # The on-disk row must agree with the response.
+ row = await db_session.get(NewChatMessage, msg_id)
+ await db_session.refresh(row)
+ assert row.content == server_content
+
+ # ``token_usage``: exactly one row, with the *server's* values
+ # (the FE's much larger token counts must not have overwritten
+ # them).
+ usage_count = (
+ await db_session.execute(
+ select(func.count())
+ .select_from(TokenUsage)
+ .where(TokenUsage.message_id == msg_id)
+ )
+ ).scalar_one()
+ assert usage_count == 1
+
+ usage = (
+ await db_session.execute(
+ select(TokenUsage).where(TokenUsage.message_id == msg_id)
+ )
+ ).scalar_one()
+ assert usage.cost_micros == 22222 # Server's value, not 88888
+ assert usage.total_tokens == 280 # Server's value, not 1998
+
+ async def test_legacy_fe_first_appendmessage_then_server_no_dupe(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ db_search_space,
+ patched_shielded_session,
+ bypass_permission_checks,
+ ):
+ """Inverse race: legacy FE writes first, server finalize second.
+
+ Some clients still post ``appendMessage`` before the streaming
+ ``finally`` runs. The contract is symmetric: whichever writer
+ loses the (thread_id, turn_id, role) race silently lets the
+ winner keep its content. In particular the *server's*
+ finalize must NOT raise — it must look up the existing row and
+ UPDATE its content with the server-built payload (which is
+ always richer/more authoritative than whatever the FE
+ snapshot held).
+ """
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ search_space_id = db_search_space.id
+ turn_id = f"{thread_id}:fe_first"
+
+ # Step 1: legacy FE appendMessage lands first. No prior shell
+ # row exists; the route does the INSERT itself.
+ fe_request_body = {
+ "role": "assistant",
+ "content": [{"type": "text", "text": "early FE write"}],
+ "turn_id": turn_id,
+ }
+ fe_response = await new_chat_routes.append_message(
+ thread_id=thread_id,
+ request=_FakeRequest(fe_request_body),
+ session=db_session,
+ user=db_user,
+ )
+ assert fe_response.role == NewChatMessageRole.ASSISTANT
+
+ # Step 2: server stream's persist_assistant_shell now races
+ # behind. It must adopt the existing row id, not raise.
+ adopted_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ assert adopted_id == fe_response.id
+
+ # Step 3: server finalize then overwrites the FE's stub with
+ # the rich content (which is the correct, more authoritative
+ # payload).
+ server_content = _build_tool_heavy_content()
+ await finalize_assistant_turn(
+ message_id=adopted_id,
+ chat_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ content=server_content,
+ accumulator=_accumulator_with_one_call(),
+ )
+
+ # Final state: one row, server content, one token_usage row.
+ msg_count = (
+ await db_session.execute(
+ select(func.count())
+ .select_from(NewChatMessage)
+ .where(
+ NewChatMessage.thread_id == thread_id,
+ NewChatMessage.turn_id == turn_id,
+ NewChatMessage.role == NewChatMessageRole.ASSISTANT,
+ )
+ )
+ ).scalar_one()
+ assert msg_count == 1
+
+ row = await db_session.get(NewChatMessage, adopted_id)
+ await db_session.refresh(row)
+ assert row.content == server_content
+
+ usage_count = (
+ await db_session.execute(
+ select(func.count())
+ .select_from(TokenUsage)
+ .where(TokenUsage.message_id == adopted_id)
+ )
+ ).scalar_one()
+ assert usage_count == 1
+
+ async def test_appendmessage_without_turn_id_legacy_400(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ bypass_permission_checks,
+ ):
+        """Defensive: a bare appendMessage with no turn_id and no
+        existing row is a plain INSERT and must succeed without ever
+        entering the unique-index recovery branch. Unrelated
+        IntegrityErrors (e.g. foreign-key failures) would still fall
+        through to the legacy 400 path; that case isn't reachable
+        today, so this test pins the happy path to guard against a
+        future schema change silently regressing it.
+        """
+ thread_id = db_thread.id
+
+ # Bare appendMessage with no turn_id — should just succeed
+ # without invoking the recovery branch.
+ ok_response = await new_chat_routes.append_message(
+ thread_id=thread_id,
+ request=_FakeRequest(
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": "hi"}],
+ }
+ ),
+ session=db_session,
+ user=db_user,
+ )
+ assert ok_response.role == NewChatMessageRole.USER
+ assert ok_response.turn_id is None
+
+ # Sanity: the route did NOT silently swallow the missing
+ # turn_id by routing through the unique-index recovery branch
+ # — it took the happy path.
+ msg_count = (
+ await db_session.execute(
+ select(func.count())
+ .select_from(NewChatMessage)
+ .where(
+ NewChatMessage.thread_id == thread_id,
+ NewChatMessage.role == NewChatMessageRole.USER,
+ )
+ )
+ ).scalar_one()
+ assert msg_count == 1
diff --git a/surfsense_backend/tests/integration/chat/test_message_id_sse.py b/surfsense_backend/tests/integration/chat/test_message_id_sse.py
new file mode 100644
index 000000000..8fc935eaa
--- /dev/null
+++ b/surfsense_backend/tests/integration/chat/test_message_id_sse.py
@@ -0,0 +1,332 @@
+"""Integration tests for the SSE-based message ID handshake.
+
+The streaming generators (``stream_new_chat`` / ``stream_resume_chat``)
+emit two new events after their respective persistence helpers resolve
+the canonical ``new_chat_messages.id``:
+
+* ``data-user-message-id`` — emitted only by ``stream_new_chat``,
+ AFTER ``persist_user_turn`` and BEFORE any LLM streaming.
+* ``data-assistant-message-id`` — emitted by both
+ ``stream_new_chat`` and ``stream_resume_chat``, AFTER
+ ``persist_assistant_shell`` and BEFORE any LLM streaming.
+
+The frontend renames its optimistic ``msg-user-XXX`` /
+``msg-assistant-XXX`` placeholder ids to ``msg-{db_id}`` upon receiving
+these events. This test suite anchors three contracts:
+
+1. ``format_data`` produces SSE bytes in the precise shape
+   ``data: {"type":"data-<name>","data":{...}}\\n\\n`` that the FE's
+ ``readSSEStream`` consumer parses (matches ``surfsense_web/lib/chat/streaming-state.ts``).
+2. The ``message_id`` carried in the SSE payload exactly equals the
+ primary key the persistence helper inserted into
+ ``new_chat_messages`` — so the FE rename produces ``msg-{real_pk}``,
+ which in turn unlocks DB-id-gated UI (comments, edit-from-message).
+3. The same ``message_id`` is used for the ``token_usage.message_id``
+ foreign key, so ``finalize_assistant_turn``'s row binds correctly.
+
+Direct end-to-end testing of ``stream_new_chat`` requires a fully
+mocked agent + LLM stack (out-of-scope here); those flows are covered
+by the harness-driven integration tests under
+``tests/integration/agents/new_chat/`` plus the assertion in
+``test_persistence.py`` that the helpers themselves return ``int``
+ids. The contracts above close the remaining gap between the persist
+helpers and the bytes that ship to the FE.
+"""
+
+from __future__ import annotations
+
+import json
+from contextlib import asynccontextmanager
+
+import pytest
+import pytest_asyncio
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import (
+ ChatVisibility,
+ NewChatMessage,
+ NewChatMessageRole,
+ NewChatThread,
+ SearchSpace,
+ User,
+)
+from app.services.new_streaming_service import VercelStreamingService
+from app.tasks.chat import persistence as persistence_module
+from app.tasks.chat.persistence import (
+ persist_assistant_shell,
+ persist_user_turn,
+)
+
+pytestmark = pytest.mark.integration
+
+
+# ---------------------------------------------------------------------------
+# Fixtures (mirror test_persistence.py)
+# ---------------------------------------------------------------------------
+
+
+@pytest_asyncio.fixture
+async def db_thread(
+ db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
+) -> NewChatThread:
+ thread = NewChatThread(
+ title="Test Chat",
+ search_space_id=db_search_space.id,
+ created_by_id=db_user.id,
+ visibility=ChatVisibility.PRIVATE,
+ )
+ db_session.add(thread)
+ await db_session.flush()
+ return thread
+
+
+@pytest.fixture
+def patched_shielded_session(monkeypatch, db_session: AsyncSession):
+ """Route persistence helpers to the test's savepoint-bound session."""
+
+ @asynccontextmanager
+ async def _fake_shielded_session():
+ yield db_session
+
+ monkeypatch.setattr(
+ persistence_module,
+ "shielded_async_session",
+ _fake_shielded_session,
+ )
+ return db_session
+
+
+# ---------------------------------------------------------------------------
+# (1) SSE byte-shape contract
+# ---------------------------------------------------------------------------
+
+
+def _parse_sse_data_line(blob: str) -> dict:
+    """Unwrap a single ``data: <json>\\n\\n`` SSE frame.
+
+ Raises if there's more than one frame or the prefix is wrong — keeps
+ the parser strict so a regression in ``format_data`` produces a
+ test failure here, not in a downstream consumer.
+ """
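+    # e.g. 'data: {"type": "data-user-message-id", "data": {"message_id": 1}}\n\n'
+    # unwraps to {"type": "data-user-message-id", "data": {"message_id": 1}}.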
+ assert blob.endswith("\n\n"), f"missing terminator: {blob!r}"
+ line = blob.removesuffix("\n\n")
+ assert line.startswith("data: "), f"missing data prefix: {line!r}"
+ return json.loads(line.removeprefix("data: "))
+
+
+class TestSSEByteShape:
+ def test_data_user_message_id_byte_shape(self):
+ """``format_data("user-message-id", {...})`` must produce the
+ exact wire format the FE's
+ ``readSSEStream`` -> ``data-user-message-id`` case parses.
+ """
+ svc = VercelStreamingService()
+ blob = svc.format_data(
+ "user-message-id",
+ {"message_id": 1843, "turn_id": "533:1762900000000"},
+ )
+ envelope = _parse_sse_data_line(blob)
+ assert envelope == {
+ "type": "data-user-message-id",
+ "data": {"message_id": 1843, "turn_id": "533:1762900000000"},
+ }
+
+ def test_data_assistant_message_id_byte_shape(self):
+ svc = VercelStreamingService()
+ blob = svc.format_data(
+ "assistant-message-id",
+ {"message_id": 1844, "turn_id": "533:1762900000000"},
+ )
+ envelope = _parse_sse_data_line(blob)
+ assert envelope == {
+ "type": "data-assistant-message-id",
+ "data": {"message_id": 1844, "turn_id": "533:1762900000000"},
+ }
+
+
+# ---------------------------------------------------------------------------
+# (2) Helper-id <-> DB-pk coherence
+# ---------------------------------------------------------------------------
+
+
+class TestHandshakeIdMatchesDB:
+ """The SSE handshake's correctness hinges on the integer in
+ ``data-{user,assistant}-message-id`` being the EXACT primary key
+ the persistence helper inserted. If they ever diverge, the FE
+ rename produces ``msg-{wrong_id}``, comments break (regex match
+ fails), and downstream features (edit, regenerate) target the
+ wrong row. Anchor it here.
+ """
+
+ async def test_user_message_id_matches_new_chat_messages_pk(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ turn_id = f"{thread_id}:9000"
+
+ # The streaming generator passes this same value into
+ # ``streaming_service.format_data("user-message-id", {...})``.
+ msg_id_from_helper = await persist_user_turn(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ user_query="hello",
+ )
+ assert isinstance(msg_id_from_helper, int)
+
+ # Look up the row the helper inserted via
+ # ``(thread_id, turn_id, role)`` — the same composite the FE
+ # uses to identify a turn — and confirm the PK matches.
+ row = (
+ await db_session.execute(
+ select(NewChatMessage).where(
+ NewChatMessage.thread_id == thread_id,
+ NewChatMessage.turn_id == turn_id,
+ NewChatMessage.role == NewChatMessageRole.USER,
+ )
+ )
+ ).scalar_one()
+ assert row.id == msg_id_from_helper
+
+ # The byte-stream the FE actually receives — confirms the
+ # round-trip from the helper return value to the SSE payload.
+ svc = VercelStreamingService()
+ envelope = _parse_sse_data_line(
+ svc.format_data(
+ "user-message-id",
+ {"message_id": msg_id_from_helper, "turn_id": turn_id},
+ )
+ )
+ assert envelope["data"]["message_id"] == row.id
+
+ async def test_assistant_message_id_matches_new_chat_messages_pk(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ turn_id = f"{thread_id}:9100"
+
+ msg_id_from_helper = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ assert isinstance(msg_id_from_helper, int)
+
+ row = (
+ await db_session.execute(
+ select(NewChatMessage).where(
+ NewChatMessage.thread_id == thread_id,
+ NewChatMessage.turn_id == turn_id,
+ NewChatMessage.role == NewChatMessageRole.ASSISTANT,
+ )
+ )
+ ).scalar_one()
+ assert row.id == msg_id_from_helper
+
+ svc = VercelStreamingService()
+ envelope = _parse_sse_data_line(
+ svc.format_data(
+ "assistant-message-id",
+ {"message_id": msg_id_from_helper, "turn_id": turn_id},
+ )
+ )
+ assert envelope["data"]["message_id"] == row.id
+
+ async def test_handshake_ids_for_full_turn_are_distinct_and_paired(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ """Sanity: a full new-chat turn's two SSE events carry two
+ DIFFERENT ids (user row PK ≠ assistant row PK), both anchored
+ to the SAME ``turn_id`` in the DB. This pairing is what
+ ``finalize_assistant_turn`` and the regenerate / edit flows
+ rely on.
+ """
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ turn_id = f"{thread_id}:9200"
+
+ user_msg_id = await persist_user_turn(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ user_query="hi",
+ )
+ assistant_msg_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ assert user_msg_id is not None and assistant_msg_id is not None
+ assert user_msg_id != assistant_msg_id
+
+ rows = (
+ (
+ await db_session.execute(
+ select(NewChatMessage)
+ .where(
+ NewChatMessage.thread_id == thread_id,
+ NewChatMessage.turn_id == turn_id,
+ )
+ .order_by(NewChatMessage.id)
+ )
+ )
+ .scalars()
+ .all()
+ )
+ assert len(rows) == 2
+ ids_by_role = {r.role: r.id for r in rows}
+ assert ids_by_role[NewChatMessageRole.USER] == user_msg_id
+ assert ids_by_role[NewChatMessageRole.ASSISTANT] == assistant_msg_id
+
+
+# ---------------------------------------------------------------------------
+# (3) Parse helpers used by the FE — sanity-check our payload shape
+# ---------------------------------------------------------------------------
+
+
+class TestPayloadShapeMatchesFEReader:
+ """The FE's ``readStreamedMessageId`` (in
+ ``surfsense_web/lib/chat/stream-side-effects.ts``) requires:
+
+ * ``message_id`` is a ``number`` (rejects null / string / NaN).
+ * ``turn_id`` is an optional non-empty string (else it's coerced
+ to ``null``).
+
+ These tests exercise the BE side of that contract by inspecting
+ ``format_data`` output shapes that the FE consumes verbatim.
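+
+ A parsed envelope carries at least the following (other keys
+ elided; values as exercised below)::
+
+ {"data": {"message_id": 42, "turn_id": "533:1762900000000"}}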
+ """
+
+ def test_message_id_is_serialised_as_a_json_number(self):
+ svc = VercelStreamingService()
+ envelope = _parse_sse_data_line(
+ svc.format_data("user-message-id", {"message_id": 42, "turn_id": "t"})
+ )
+ assert isinstance(envelope["data"]["message_id"], int)
+ assert envelope["data"]["message_id"] == 42
+
+ def test_turn_id_round_trips_as_string(self):
+ svc = VercelStreamingService()
+ # The actual format used in production: f"{chat_id}:{int(time.time()*1000)}"
+ production_turn_id = "533:1762900000000"
+ envelope = _parse_sse_data_line(
+ svc.format_data(
+ "assistant-message-id",
+ {"message_id": 1, "turn_id": production_turn_id},
+ )
+ )
+ assert envelope["data"]["turn_id"] == production_turn_id
diff --git a/surfsense_backend/tests/integration/chat/test_persistence.py b/surfsense_backend/tests/integration/chat/test_persistence.py
new file mode 100644
index 000000000..66a04772e
--- /dev/null
+++ b/surfsense_backend/tests/integration/chat/test_persistence.py
@@ -0,0 +1,747 @@
+"""Integration tests for ``app.tasks.chat.persistence``.
+
+Verifies the DB-side guarantees the streaming chat task relies on:
+
+* ``persist_assistant_shell`` is idempotent against the
+ ``(thread_id, turn_id, ASSISTANT)`` partial unique index from
+ migration 141. Two calls with the same ``turn_id`` return the SAME
+ ``message_id`` and never create a duplicate ``new_chat_messages`` row.
+* ``finalize_assistant_turn`` writes a status-marker payload when given
+ empty content, never raises, and is safe to call twice on the same
+ ``message_id`` — the partial unique index from migration 142
+ (``uq_token_usage_message_id``) prevents the second insert from
+ producing a duplicate ``token_usage`` row.
+* The same ``ON CONFLICT DO NOTHING`` invariant covers the cross-writer
+ race where ``finalize_assistant_turn`` and the ``append_message``
+ recovery branch both target the same ``message_id``.
+
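+A sketch of the migration-142 index in Alembic terms (assumed shape;
+the real migration may differ)::
+
+ op.create_index(
+ "uq_token_usage_message_id",
+ "token_usage",
+ ["message_id"],
+ unique=True,
+ postgresql_where=sa.text("message_id IS NOT NULL"),
+ )
+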
+All tests run inside the conftest's outer-transaction-with-savepoint
+fixture so commits inside the helpers (which open their own
+``shielded_async_session``) are released as savepoints and rolled back
+at test end. We monkey-patch ``shielded_async_session`` to yield the
+same pooled test session so the integration transaction stays
+in-scope.
+"""
+
+from __future__ import annotations
+
+from contextlib import asynccontextmanager
+
+import pytest
+import pytest_asyncio
+from sqlalchemy import func, select
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import (
+ ChatVisibility,
+ NewChatMessage,
+ NewChatMessageRole,
+ NewChatThread,
+ SearchSpace,
+ TokenUsage,
+ User,
+)
+from app.services.token_tracking_service import TurnTokenAccumulator
+from app.tasks.chat import persistence as persistence_module
+from app.tasks.chat.persistence import (
+ finalize_assistant_turn,
+ persist_assistant_shell,
+ persist_user_turn,
+)
+
+pytestmark = pytest.mark.integration
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest_asyncio.fixture
+async def db_thread(
+ db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
+) -> NewChatThread:
+ thread = NewChatThread(
+ title="Test Chat",
+ search_space_id=db_search_space.id,
+ created_by_id=db_user.id,
+ visibility=ChatVisibility.PRIVATE,
+ )
+ db_session.add(thread)
+ await db_session.flush()
+ return thread
+
+
+@pytest.fixture
+def patched_shielded_session(monkeypatch, db_session: AsyncSession):
+ """Route persistence helpers to the test's savepoint-bound session.
+
+ The persistence helpers use ``async with shielded_async_session() as
+ ws`` and call ``ws.commit()`` internally. Inside the conftest's
+ ``join_transaction_mode="create_savepoint"`` setup, those commits
+ release a SAVEPOINT instead of committing the outer transaction —
+ so the test session can see helper-staged rows immediately and the
+ outer rollback at end of test wipes them.
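+
+ A sketch of the conftest wiring this assumes::
+
+ conn = await engine.connect()
+ trans = await conn.begin()
+ session = AsyncSession(bind=conn, join_transaction_mode="create_savepoint")
+ # helper commits inside the test release SAVEPOINTs
+ await trans.rollback()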
+ """
+
+ @asynccontextmanager
+ async def _fake_shielded_session():
+ yield db_session
+ # Do NOT close — the outer fixture owns the session lifecycle.
+
+ monkeypatch.setattr(
+ persistence_module,
+ "shielded_async_session",
+ _fake_shielded_session,
+ )
+ return db_session
+
+
+def _accumulator_with_one_call() -> TurnTokenAccumulator:
+ acc = TurnTokenAccumulator()
+ acc.add(
+ model="gpt-4o-mini",
+ prompt_tokens=100,
+ completion_tokens=50,
+ total_tokens=150,
+ cost_micros=12345,
+ )
+ return acc
+
+
+async def _count_assistant_rows(
+ session: AsyncSession, thread_id: int, turn_id: str
+) -> int:
+ result = await session.execute(
+ select(func.count())
+ .select_from(NewChatMessage)
+ .where(
+ NewChatMessage.thread_id == thread_id,
+ NewChatMessage.turn_id == turn_id,
+ NewChatMessage.role == NewChatMessageRole.ASSISTANT,
+ )
+ )
+ return int(result.scalar_one())
+
+
+async def _count_token_usage_rows(session: AsyncSession, message_id: int) -> int:
+ result = await session.execute(
+ select(func.count())
+ .select_from(TokenUsage)
+ .where(TokenUsage.message_id == message_id)
+ )
+ return int(result.scalar_one())
+
+
+# ---------------------------------------------------------------------------
+# persist_assistant_shell
+# ---------------------------------------------------------------------------
+
+
+class TestPersistAssistantShell:
+ async def test_first_call_inserts_empty_shell_and_returns_id(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ # Capture primitive ids before any persistence helper runs:
+ # the helpers commit/rollback the shared test session, which
+ # can detach ORM rows mid-test.
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ turn_id = f"{thread_id}:1000"
+
+ msg_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ assert msg_id is not None and isinstance(msg_id, int)
+
+ row = await db_session.get(NewChatMessage, msg_id)
+ assert row is not None
+ assert row.thread_id == thread_id
+ assert row.role == NewChatMessageRole.ASSISTANT
+ assert row.turn_id == turn_id
+ # Empty shell payload — finalize_assistant_turn overwrites later.
+ assert row.content == [{"type": "text", "text": ""}]
+
+ async def test_second_call_with_same_turn_id_returns_same_id(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ # Capture primitive ids before any persistence helper runs:
+ # the helpers commit/rollback the shared test session, which
+ # can detach ORM rows mid-test.
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ turn_id = f"{thread_id}:2000"
+
+ first_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ second_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+
+ assert first_id is not None
+ assert first_id == second_id
+ # Exactly one row in the DB for this turn.
+ assert await _count_assistant_rows(db_session, thread_id, turn_id) == 1
+
+ async def test_missing_turn_id_returns_none(
+ self,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+
+ msg_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id="",
+ )
+ assert msg_id is None
+
+ async def test_after_persist_user_turn_resolves_assistant_id(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ turn_id = f"{thread_id}:3000"
+
+ # The streaming layer always calls persist_user_turn first, so
+ # smoke-test the canonical sequence.
+ user_msg_id = await persist_user_turn(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ user_query="hello",
+ )
+ assert isinstance(user_msg_id, int)
+
+ msg_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ assert msg_id is not None
+ # User row + assistant shell row = 2 rows for this turn.
+ result = await db_session.execute(
+ select(func.count())
+ .select_from(NewChatMessage)
+ .where(
+ NewChatMessage.thread_id == thread_id,
+ NewChatMessage.turn_id == turn_id,
+ )
+ )
+ assert result.scalar_one() == 2
+
+ async def test_double_call_with_same_turn_id_uses_on_conflict(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ """Verifies the ON CONFLICT DO NOTHING path on the assistant
+ shell does not raise ``IntegrityError`` even when the second
+ writer races the first within a tight loop. ``test_second_call_with_same_turn_id_returns_same_id``
+ already covers the same-id semantics; this test additionally
+ asserts neither call raises so the debugger's
+ ``raise-on-IntegrityError`` setting won't pause the streaming
+ path under contention.
+ """
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ turn_id = f"{thread_id}:3500"
+
+ # Both calls go through ``pg_insert(...).on_conflict_do_nothing``;
+ # the second one returns RETURNING=∅ and falls into the SELECT
+ # branch. Neither path raises.
+ first_id = await persist_assistant_shell(
+ chat_id=thread_id, user_id=user_id_str, turn_id=turn_id
+ )
+ second_id = await persist_assistant_shell(
+ chat_id=thread_id, user_id=user_id_str, turn_id=turn_id
+ )
+ assert first_id is not None
+ assert first_id == second_id
+
+
+# ---------------------------------------------------------------------------
+# persist_user_turn
+# ---------------------------------------------------------------------------
+
+
+class TestPersistUserTurn:
+ async def test_returns_message_id_on_first_insert(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ turn_id = f"{thread_id}:8000"
+
+ msg_id = await persist_user_turn(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ user_query="hello",
+ )
+ assert isinstance(msg_id, int) and msg_id > 0
+
+ row = await db_session.get(NewChatMessage, msg_id)
+ assert row is not None
+ assert row.thread_id == thread_id
+ assert row.role == NewChatMessageRole.USER
+ assert row.turn_id == turn_id
+ assert row.content == [{"type": "text", "text": "hello"}]
+
+ async def test_returns_existing_id_on_conflict(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ turn_id = f"{thread_id}:8100"
+
+ first_id = await persist_user_turn(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ user_query="hello",
+ )
+ # Second call simulates a legacy FE ``appendMessage`` racing the
+ # SSE stream: ON CONFLICT DO NOTHING short-circuits at the DB
+ # level, the helper recovers the existing id via SELECT, and
+ # crucially does NOT raise ``IntegrityError`` (the debugger
+ # would otherwise pause on it).
+ second_id = await persist_user_turn(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ user_query="ignored on conflict",
+ )
+ assert first_id is not None
+ assert first_id == second_id
+
+ # Exactly one user row for this turn.
+ count = await db_session.execute(
+ select(func.count())
+ .select_from(NewChatMessage)
+ .where(
+ NewChatMessage.thread_id == thread_id,
+ NewChatMessage.turn_id == turn_id,
+ NewChatMessage.role == NewChatMessageRole.USER,
+ )
+ )
+ assert count.scalar_one() == 1
+
+ async def test_embeds_mentioned_documents_part(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ """The full ``{id, title, document_type}`` triple forwarded by
+ the FE must round-trip into a single ``mentioned-documents``
+ ContentPart on the persisted user message — the history loader
+ renders the chips on reload from this part directly.
+ """
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ turn_id = f"{thread_id}:8200"
+
+ mentioned = [
+ {"id": 11, "title": "Alpha", "document_type": "GENERAL"},
+ {"id": 22, "title": "Beta", "document_type": "GENERAL"},
+ ]
+ msg_id = await persist_user_turn(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ user_query="hello",
+ mentioned_documents=mentioned,
+ )
+ assert isinstance(msg_id, int)
+
+ row = await db_session.get(NewChatMessage, msg_id)
+ assert row is not None
+ # Content is a 2-part list: text + mentioned-documents.
+ assert isinstance(row.content, list)
+ assert row.content[0] == {"type": "text", "text": "hello"}
+ assert row.content[1] == {
+ "type": "mentioned-documents",
+ "documents": [
+ {"id": 11, "title": "Alpha", "document_type": "GENERAL"},
+ {"id": 22, "title": "Beta", "document_type": "GENERAL"},
+ ],
+ }
+
+ async def test_skips_mentioned_documents_when_empty_or_invalid(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ """Empty list and entries missing required fields are dropped;
+ a ``mentioned-documents`` part is only emitted when at least
+ one normalised entry survived.
+ """
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ turn_id_empty = f"{thread_id}:8300"
+ turn_id_invalid = f"{thread_id}:8301"
+
+ msg_id_empty = await persist_user_turn(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id_empty,
+ user_query="hi",
+ mentioned_documents=[],
+ )
+ assert isinstance(msg_id_empty, int)
+ row_empty = await db_session.get(NewChatMessage, msg_id_empty)
+ assert row_empty is not None
+ assert row_empty.content == [{"type": "text", "text": "hi"}]
+
+ # Each entry missing one required field — all skipped.
+ msg_id_invalid = await persist_user_turn(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id_invalid,
+ user_query="hi",
+ mentioned_documents=[
+ {"title": "no id", "document_type": "GENERAL"}, # missing id
+ {"id": 99, "document_type": "GENERAL"}, # missing title
+ {"id": 100, "title": "no type"}, # missing document_type
+ ],
+ )
+ assert isinstance(msg_id_invalid, int)
+ row_invalid = await db_session.get(NewChatMessage, msg_id_invalid)
+ assert row_invalid is not None
+ assert row_invalid.content == [{"type": "text", "text": "hi"}]
+
+ async def test_missing_turn_id_returns_none(
+ self,
+ db_user,
+ db_thread,
+ patched_shielded_session,
+ ):
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+
+ msg_id = await persist_user_turn(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id="",
+ user_query="hello",
+ )
+ assert msg_id is None
+
+
+# ---------------------------------------------------------------------------
+# finalize_assistant_turn
+# ---------------------------------------------------------------------------
+
+
+class TestFinalizeAssistantTurn:
+ async def test_writes_content_and_token_usage(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ db_search_space,
+ patched_shielded_session,
+ ):
+ thread_id = db_thread.id
+ user_id_uuid = db_user.id
+ user_id_str = str(user_id_uuid)
+ search_space_id = db_search_space.id
+ turn_id = f"{thread_id}:4000"
+
+ msg_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ assert msg_id is not None
+
+ rich_content = [
+ {"type": "text", "text": "Hello world"},
+ {
+ "type": "tool-call",
+ "toolCallId": "call_x",
+ "toolName": "ls",
+ "args": {"path": "/"},
+ "argsText": '{\n "path": "/"\n}',
+ "result": {"files": []},
+ "langchainToolCallId": "lc_x",
+ },
+ ]
+ await finalize_assistant_turn(
+ message_id=msg_id,
+ chat_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ content=rich_content,
+ accumulator=_accumulator_with_one_call(),
+ )
+
+ row = await db_session.get(NewChatMessage, msg_id)
+ await db_session.refresh(row)
+ assert row.content == rich_content
+
+ # Exactly one token_usage row keyed on this message_id.
+ usage_rows = (
+ (
+ await db_session.execute(
+ select(TokenUsage).where(TokenUsage.message_id == msg_id)
+ )
+ )
+ .scalars()
+ .all()
+ )
+ assert len(usage_rows) == 1
+ usage = usage_rows[0]
+ assert usage.usage_type == "chat"
+ assert usage.prompt_tokens == 100
+ assert usage.completion_tokens == 50
+ assert usage.total_tokens == 150
+ assert usage.cost_micros == 12345
+ assert usage.thread_id == thread_id
+ assert usage.search_space_id == search_space_id
+
+ async def test_empty_content_writes_status_marker(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ db_search_space,
+ patched_shielded_session,
+ ):
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ search_space_id = db_search_space.id
+ turn_id = f"{thread_id}:5000"
+
+ msg_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ assert msg_id is not None
+
+ # Pure tool-call turn that aborted before any output, or
+ # interrupt before any event arrived — empty list.
+ await finalize_assistant_turn(
+ message_id=msg_id,
+ chat_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ content=[],
+ accumulator=None,
+ )
+
+ row = await db_session.get(NewChatMessage, msg_id)
+ await db_session.refresh(row)
+ assert row.content == [{"type": "status", "text": "(no text response)"}]
+
+ async def test_double_call_safe_via_on_conflict(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ db_search_space,
+ patched_shielded_session,
+ ):
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ search_space_id = db_search_space.id
+ turn_id = f"{thread_id}:6000"
+
+ msg_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ assert msg_id is not None
+
+ first_acc = _accumulator_with_one_call()
+ await finalize_assistant_turn(
+ message_id=msg_id,
+ chat_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ content=[{"type": "text", "text": "first finalize"}],
+ accumulator=first_acc,
+ )
+
+ # Simulate a follow-up finalize (e.g., resume retry within the
+ # shielded finally block firing twice). Different content, but
+ # ON CONFLICT DO NOTHING on token_usage means the cost from the
+ # first finalize stays authoritative.
+ second_acc = TurnTokenAccumulator()
+ second_acc.add(
+ model="gpt-4o-mini",
+ prompt_tokens=999,
+ completion_tokens=999,
+ total_tokens=1998,
+ cost_micros=99999,
+ )
+ await finalize_assistant_turn(
+ message_id=msg_id,
+ chat_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ content=[{"type": "text", "text": "second finalize"}],
+ accumulator=second_acc,
+ )
+
+ # Content was overwritten by the second UPDATE.
+ row = await db_session.get(NewChatMessage, msg_id)
+ await db_session.refresh(row)
+ assert row.content == [{"type": "text", "text": "second finalize"}]
+
+ # But token_usage stayed at exactly one row, preserving the
+ # first finalize's authoritative cost.
+ assert await _count_token_usage_rows(db_session, msg_id) == 1
+ usage = (
+ await db_session.execute(
+ select(TokenUsage).where(TokenUsage.message_id == msg_id)
+ )
+ ).scalar_one()
+ assert usage.cost_micros == 12345 # First finalize's value
+
+ async def test_append_message_style_insert_after_finalize_no_dupe(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ db_search_space,
+ patched_shielded_session,
+ ):
+ """Cross-writer race: ``append_message`` arrives after ``finalize_assistant_turn``.
+
+ Both target the same ``message_id``; the partial unique index
+ ``uq_token_usage_message_id`` (migration 142) makes the second
+ insert a no-op via ``ON CONFLICT DO NOTHING``.
+ """
+ from sqlalchemy import text as sa_text
+
+ thread_id = db_thread.id
+ user_uuid = db_user.id
+ user_id_str = str(user_uuid)
+ search_space_id = db_search_space.id
+ turn_id = f"{thread_id}:7000"
+
+ msg_id = await persist_assistant_shell(
+ chat_id=thread_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ )
+ assert msg_id is not None
+
+ await finalize_assistant_turn(
+ message_id=msg_id,
+ chat_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=user_id_str,
+ turn_id=turn_id,
+ content=[{"type": "text", "text": "from server"}],
+ accumulator=_accumulator_with_one_call(),
+ )
+
+ # Now simulate the FE's append_message branch firing AFTER —
+ # the same INSERT ... ON CONFLICT DO NOTHING shape used by the
+ # route handler, keyed on the migration-142 partial unique
+ # index.
+ late_insert = (
+ pg_insert(TokenUsage)
+ .values(
+ usage_type="chat",
+ prompt_tokens=42,
+ completion_tokens=42,
+ total_tokens=84,
+ cost_micros=1,
+ model_breakdown=None,
+ call_details=None,
+ thread_id=thread_id,
+ message_id=msg_id,
+ search_space_id=search_space_id,
+ user_id=user_uuid,
+ )
+ .on_conflict_do_nothing(
+ index_elements=["message_id"],
+ index_where=sa_text("message_id IS NOT NULL"),
+ )
+ )
+ await db_session.execute(late_insert)
+ await db_session.flush()
+
+ # Still exactly one row, with the original (server) cost value.
+ assert await _count_token_usage_rows(db_session, msg_id) == 1
+ usage = (
+ await db_session.execute(
+ select(TokenUsage).where(TokenUsage.message_id == msg_id)
+ )
+ ).scalar_one()
+ assert usage.cost_micros == 12345
+
+ async def test_helper_never_raises_on_missing_message_id(
+ self,
+ db_session,
+ db_user,
+ db_thread,
+ db_search_space,
+ patched_shielded_session,
+ ):
+ thread_id = db_thread.id
+ user_id_str = str(db_user.id)
+ search_space_id = db_search_space.id
+
+ # message_id that doesn't exist — finalize must log+return,
+ # never raise (called from shielded finally).
+ await finalize_assistant_turn(
+ message_id=999_999_999,
+ chat_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=user_id_str,
+ turn_id="anything",
+ content=[{"type": "text", "text": "x"}],
+ accumulator=_accumulator_with_one_call(),
+ )
+ # If we got here without an exception, the test passes.
+ # Sanity: no token_usage row created (FK to message would have
+ # been rejected anyway, but ON CONFLICT path may swallow
+ # FK errors as well; check directly).
+ assert await _count_token_usage_rows(db_session, 999_999_999) == 0
diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py
index 397b1c787..36fe04aa2 100644
--- a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py
+++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py
@@ -226,6 +226,31 @@ class TestCompose:
# Default block should NOT be present
assert "" not in prompt
+ def test_provider_hints_render_with_custom_system_instructions(
+ self, fixed_today: datetime
+ ) -> None:
+ """Regression guard for the always-append decision: provider hints
+ append AFTER a custom system prompt.
+
+ Provider hints are stylistic nudges (parallel tool-call rules,
+ formatting guidance, etc.) that help the model regardless of
+ what the system instructions say. Suppressing them when a
+ custom prompt is set would partially defeat the per-family
+ prompt machinery.
+ """
+ prompt = compose_system_prompt(
+ today=fixed_today,
+ custom_system_instructions="You are a custom assistant.",
+ model_name="anthropic/claude-3-5-sonnet",
+ )
+ assert "You are a custom assistant." in prompt
+ assert "" in prompt
+ # The custom prompt must come BEFORE the provider hints so the
+ # user's framing isn't drowned out by the stylistic nudges.
+ assert prompt.index("You are a custom assistant.") < prompt.index(
+ ""
+ )
+
def test_use_default_false_with_no_custom_yields_no_system_block(
self, fixed_today: datetime
) -> None:
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py
new file mode 100644
index 000000000..9b3de2db7
--- /dev/null
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py
@@ -0,0 +1,268 @@
+"""Regression tests for the compiled-agent cache.
+
+Covers the cache primitive itself (TTL, LRU, in-flight de-duplication,
+build-failure non-caching) and the cache-key signature helpers that
+``create_surfsense_deep_agent`` relies on. The integration with
+``create_surfsense_deep_agent`` is covered separately by the streaming
+contract tests; this module focuses on the primitives so a regression
+in the cache implementation is caught before it reaches the agent
+factory.
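+
+How the pieces are meant to combine (illustrative only; the real key
+composition lives in ``create_surfsense_deep_agent``)::
+
+ key = stable_hash(
+ tools_signature(tools, available_connectors=cs, available_document_types=ds),
+ flags_signature(flags),
+ system_prompt_hash(prompt),
+ )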
+"""
+
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass
+
+import pytest
+
+from app.agents.new_chat.agent_cache import (
+ flags_signature,
+ reload_for_tests,
+ stable_hash,
+ system_prompt_hash,
+ tools_signature,
+)
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# stable_hash + signature helpers
+# ---------------------------------------------------------------------------
+
+
+def test_stable_hash_is_deterministic_across_calls() -> None:
+ a = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
+ b = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
+ assert a == b
+
+
+def test_stable_hash_changes_when_any_part_changes() -> None:
+ base = stable_hash("v1", 42, "thread-9")
+ assert stable_hash("v1", 42, "thread-10") != base
+ assert stable_hash("v2", 42, "thread-9") != base
+ assert stable_hash("v1", 43, "thread-9") != base
+
+
+def test_tools_signature_keys_on_name_and_description_not_identity() -> None:
+ """Two tool lists with the same surface must hash identically.
+
+ The cache key MUST NOT change when the underlying ``BaseTool``
+ instances are different Python objects (a fresh request constructs
+ fresh tool instances every time). Hashing on ``(name, description)``
+ keeps the cache hot across requests with identical tool surfaces.
+ """
+
+ @dataclass
+ class FakeTool:
+ name: str
+ description: str
+
+ tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")]
+ tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")]
+ sig_a = tools_signature(
+ tools_a, available_connectors=["NOTION"], available_document_types=["FILE"]
+ )
+ sig_b = tools_signature(
+ tools_b, available_connectors=["NOTION"], available_document_types=["FILE"]
+ )
+ assert sig_a == sig_b, "tool order must not affect the signature"
+
+ # Adding a tool rotates the key.
+ tools_c = [*tools_a, FakeTool("gamma", "does gamma")]
+ sig_c = tools_signature(
+ tools_c, available_connectors=["NOTION"], available_document_types=["FILE"]
+ )
+ assert sig_c != sig_a
+
+
+def test_tools_signature_rotates_when_connector_set_changes() -> None:
+ @dataclass
+ class FakeTool:
+ name: str
+ description: str
+
+ tools = [FakeTool("a", "x")]
+ base = tools_signature(
+ tools, available_connectors=["NOTION"], available_document_types=["FILE"]
+ )
+ added = tools_signature(
+ tools,
+ available_connectors=["NOTION", "SLACK"],
+ available_document_types=["FILE"],
+ )
+ assert base != added, "adding a connector must rotate the cache key"
+
+
+def test_flags_signature_changes_when_flag_flips() -> None:
+ @dataclass(frozen=True)
+ class Flags:
+ a: bool = True
+ b: bool = False
+
+ base = flags_signature(Flags())
+ flipped = flags_signature(Flags(b=True))
+ assert base != flipped
+
+
+def test_system_prompt_hash_is_stable_and_distinct() -> None:
+ p1 = "You are a helpful assistant."
+ p2 = "You are a helpful assistant!" # one-character delta
+ assert system_prompt_hash(p1) == system_prompt_hash(p1)
+ assert system_prompt_hash(p1) != system_prompt_hash(p2)
+
+
+# ---------------------------------------------------------------------------
+# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_cache_hit_returns_same_instance_on_second_call() -> None:
+ cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
+ builds = 0
+
+ async def builder() -> object:
+ nonlocal builds
+ builds += 1
+ return object()
+
+ a = await cache.get_or_build("k", builder=builder)
+ b = await cache.get_or_build("k", builder=builder)
+ assert a is b, "cache must return the SAME object across hits"
+ assert builds == 1, "builder must run exactly once"
+
+
+@pytest.mark.asyncio
+async def test_cache_different_keys_get_different_instances() -> None:
+ cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
+
+ async def builder() -> object:
+ return object()
+
+ a = await cache.get_or_build("k1", builder=builder)
+ b = await cache.get_or_build("k2", builder=builder)
+ assert a is not b
+
+
+@pytest.mark.asyncio
+async def test_cache_stale_entries_get_rebuilt() -> None:
+ # ttl=0 means every read sees the entry as immediately stale.
+ cache = reload_for_tests(maxsize=8, ttl_seconds=0.0)
+ builds = 0
+
+ async def builder() -> object:
+ nonlocal builds
+ builds += 1
+ return object()
+
+ a = await cache.get_or_build("k", builder=builder)
+ b = await cache.get_or_build("k", builder=builder)
+ assert a is not b, "stale entry must rebuild a fresh instance"
+ assert builds == 2
+
+
+@pytest.mark.asyncio
+async def test_cache_evicts_lru_when_full() -> None:
+ cache = reload_for_tests(maxsize=2, ttl_seconds=60.0)
+
+ async def builder() -> object:
+ return object()
+
+ a = await cache.get_or_build("a", builder=builder)
+ _ = await cache.get_or_build("b", builder=builder)
+ # Re-touch "a" so "b" is now the LRU victim.
+ a_again = await cache.get_or_build("a", builder=builder)
+ assert a_again is a
+ # Inserting "c" should evict "b" (LRU), not "a".
+ _ = await cache.get_or_build("c", builder=builder)
+ assert cache.stats()["size"] == 2
+
+ # Confirm "a" is still hot (no rebuild) and "b" is gone (rebuild).
+ a_hit = await cache.get_or_build("a", builder=builder)
+ assert a_hit is a, "LRU must keep the most-recently-used 'a' entry"
+
+
+@pytest.mark.asyncio
+async def test_cache_concurrent_misses_coalesce_to_single_build() -> None:
+ """Two concurrent get_or_build calls on the same key must share one builder."""
+ cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
+ build_started = asyncio.Event()
+ builds = 0
+
+ async def slow_builder() -> object:
+ nonlocal builds
+ builds += 1
+ build_started.set()
+ # Yield control so the second waiter can race against us.
+ await asyncio.sleep(0.05)
+ return object()
+
+ task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
+ # Wait until the first builder has started, then race a second waiter.
+ await build_started.wait()
+ task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
+
+ a, b = await asyncio.gather(task_a, task_b)
+ assert a is b, "coalesced waiters must observe the same value"
+ assert builds == 1, "concurrent cold misses must collapse to ONE build"
+
+
+@pytest.mark.asyncio
+async def test_cache_does_not_store_failed_builds() -> None:
+ """A builder that raises must NOT poison the cache.
+
+ The next caller for the same key must run the builder again (not
+ re-raise the cached exception).
+ """
+ cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
+ attempts = 0
+
+ async def flaky_builder() -> object:
+ nonlocal attempts
+ attempts += 1
+ if attempts == 1:
+ raise RuntimeError("transient")
+ return object()
+
+ with pytest.raises(RuntimeError, match="transient"):
+ await cache.get_or_build("k", builder=flaky_builder)
+
+ # Second call must retry — not re-raise the cached exception.
+ value = await cache.get_or_build("k", builder=flaky_builder)
+ assert value is not None
+ assert attempts == 2
+
+
+@pytest.mark.asyncio
+async def test_cache_invalidate_drops_entry() -> None:
+ cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
+
+ async def builder() -> object:
+ return object()
+
+ a = await cache.get_or_build("k", builder=builder)
+ assert cache.invalidate("k") is True
+ b = await cache.get_or_build("k", builder=builder)
+ assert a is not b, "post-invalidation lookup must rebuild"
+
+
+@pytest.mark.asyncio
+async def test_cache_invalidate_prefix_drops_matching_entries() -> None:
+ cache = reload_for_tests(maxsize=16, ttl_seconds=60.0)
+
+ async def builder() -> object:
+ return object()
+
+ await cache.get_or_build("user:1:thread:1", builder=builder)
+ await cache.get_or_build("user:1:thread:2", builder=builder)
+ await cache.get_or_build("user:2:thread:1", builder=builder)
+
+ removed = cache.invalidate_prefix("user:1:")
+ assert removed == 2
+ assert cache.stats()["size"] == 1
+
+ # The user:2 entry must still be hot (no rebuild).
+ survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder)
+ assert survivor_value is not None
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py
index 0c7bf17f6..f0161f605 100644
--- a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py
@@ -7,7 +7,9 @@ import pytest
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import (
BusyMutexMiddleware,
+ end_turn,
get_cancel_event,
+ is_cancel_requested,
manager,
request_cancel,
reset_cancel,
@@ -88,3 +90,65 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
def test_reset_cancel_idempotent() -> None:
# Should not raise even if event was never created
reset_cancel("never-seen")
+
+
+def test_request_cancel_creates_event_for_unseen_thread() -> None:
+ thread_id = "never-seen-cancel"
+ reset_cancel(thread_id)
+
+ assert request_cancel(thread_id) is True
+ assert get_cancel_event(thread_id).is_set()
+ assert is_cancel_requested(thread_id) is True
+
+
+@pytest.mark.asyncio
+async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
+ thread_id = "forced-end-turn"
+ mw = BusyMutexMiddleware()
+ runtime = _Runtime(thread_id)
+
+ await mw.abefore_agent({}, runtime)
+ assert manager.lock_for(thread_id).locked()
+
+ request_cancel(thread_id)
+ assert is_cancel_requested(thread_id) is True
+
+ end_turn(thread_id)
+
+ assert not manager.lock_for(thread_id).locked()
+ assert not get_cancel_event(thread_id).is_set()
+ assert is_cancel_requested(thread_id) is False
+
+
+@pytest.mark.asyncio
+async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None:
+ """A stale aafter call from attempt A must not unlock attempt B.
+
+ Repro flow:
+ 1) attempt A acquires thread lock
+ 2) forced end_turn clears A so retry can proceed
+ 3) attempt B acquires same thread lock
+ 4) stale attempt-A aafter runs late
+
+ Expected: B lock remains held.
+ """
+ thread_id = "stale-aafter-lock"
+ runtime = _Runtime(thread_id)
+ attempt_a = BusyMutexMiddleware()
+ attempt_b = BusyMutexMiddleware()
+
+ await attempt_a.abefore_agent({}, runtime)
+ lock = manager.lock_for(thread_id)
+ assert lock.locked()
+
+ end_turn(thread_id)
+ assert not lock.locked()
+
+ await attempt_b.abefore_agent({}, runtime)
+ assert lock.locked()
+
+ # Stale cleanup from attempt A must not release attempt B's lock.
+ await attempt_a.aafter_agent({}, runtime)
+ assert lock.locked()
+
+ await attempt_b.aafter_agent({}, runtime)
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py
index 38a70a443..6800be2af 100644
--- a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py
@@ -31,18 +31,45 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"SURFSENSE_ENABLE_ACTION_LOG",
"SURFSENSE_ENABLE_REVERT_ROUTE",
+ "SURFSENSE_ENABLE_STREAM_PARITY_V2",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
"SURFSENSE_ENABLE_OTEL",
+ "SURFSENSE_ENABLE_AGENT_CACHE",
+ "SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT",
]:
monkeypatch.delenv(name, raising=False)
-def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None:
+def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None:
_clear_all(monkeypatch)
flags = reload_for_tests()
assert isinstance(flags, AgentFeatureFlags)
assert flags.disable_new_agent_stack is False
- assert flags.any_new_middleware_enabled() is False
+ assert flags.enable_context_editing is True
+ assert flags.enable_compaction_v2 is True
+ assert flags.enable_retry_after is True
+ assert flags.enable_model_fallback is False
+ assert flags.enable_model_call_limit is True
+ assert flags.enable_tool_call_limit is True
+ assert flags.enable_tool_call_repair is True
+ assert flags.enable_doom_loop is True
+ assert flags.enable_permission is True
+ assert flags.enable_busy_mutex is True
+ assert flags.enable_llm_tool_selector is False
+ assert flags.enable_skills is True
+ assert flags.enable_specialized_subagents is True
+ assert flags.enable_kb_planner_runnable is True
+ assert flags.enable_action_log is True
+ assert flags.enable_revert_route is True
+ assert flags.enable_stream_parity_v2 is True
+ assert flags.enable_plugin_loader is False
+ assert flags.enable_otel is False
+ # Phase 2: agent cache is now default-on (the prerequisite tool
+ # ``db_session`` refactor landed). The companion gp-subagent share
+ # flag stays default-off pending data on cold-miss frequency.
+ assert flags.enable_agent_cache is True
+ assert flags.enable_agent_cache_share_gp_subagent is False
+ assert flags.any_new_middleware_enabled() is True
def test_master_kill_switch_overrides_individual_flags(
@@ -100,21 +127,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
+ "enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
"enable_otel": "SURFSENSE_ENABLE_OTEL",
}
- # `enable_otel` is intentionally orthogonal — it does NOT count toward
- # ``any_new_middleware_enabled`` because OTel is observability-only and
- # ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
- counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}
-
for attr, env_name in flag_to_env.items():
_clear_all(monkeypatch)
- monkeypatch.setenv(env_name, "true")
+ monkeypatch.setenv(env_name, "false")
flags = reload_for_tests()
- assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}"
- if attr in counts_toward_middleware:
- assert flags.any_new_middleware_enabled() is True
- else:
- assert flags.any_new_middleware_enabled() is False
+ assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}"
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py
new file mode 100644
index 000000000..6c323d920
--- /dev/null
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py
@@ -0,0 +1,344 @@
+"""Tests for ``FlattenSystemMessageMiddleware``.
+
+The middleware exists to defend against Anthropic's "Found 5 cache_control
+blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on
+the system message and the OpenRouter→Anthropic adapter redistributes
+``cache_control`` across all of them. The flattening collapses every
+all-text system content list to a single string before the LLM call.
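+
+Shape of the transform (illustrative)::
+
+ [{"type": "text", "text": "a"}, {"type": "text", "text": "b"}]
+ --> "a\n\nb"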
+"""
+
+from __future__ import annotations
+
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+from langchain_core.messages import HumanMessage, SystemMessage
+
+from app.agents.new_chat.middleware.flatten_system import (
+ FlattenSystemMessageMiddleware,
+ _flatten_text_blocks,
+ _flattened_request,
+)
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# _flatten_text_blocks — pure helper, the heart of the middleware.
+# ---------------------------------------------------------------------------
+
+
+class TestFlattenTextBlocks:
+ def test_joins_text_blocks_with_double_newline(self) -> None:
+ blocks = [
+ {"type": "text", "text": "first"},
+ {"type": "text", "text": "second"},
+ {"type": "text", "text": "third"},
+ ]
+ assert (
+ _flatten_text_blocks(blocks)
+ == "first\n\nsecond\n\nthird"
+ )
+
+ def test_handles_single_text_block(self) -> None:
+ blocks = [{"type": "text", "text": "only one"}]
+ assert _flatten_text_blocks(blocks) == "only one"
+
+ def test_handles_empty_list(self) -> None:
+ assert _flatten_text_blocks([]) == ""
+
+ def test_passes_through_bare_string_blocks(self) -> None:
+ # LangChain content can mix bare strings and dict blocks.
+ blocks = ["raw string", {"type": "text", "text": "dict block"}]
+ assert _flatten_text_blocks(blocks) == "raw string\n\ndict block"
+
+ def test_returns_none_for_image_block(self) -> None:
+ # System messages with images are rare — but we never want to
+ # silently lose the image payload by joining as text.
+ blocks = [
+ {"type": "text", "text": "look at this"},
+ {"type": "image_url", "image_url": {"url": "data:image/png..."}},
+ ]
+ assert _flatten_text_blocks(blocks) is None
+
+ def test_returns_none_for_non_dict_non_str_block(self) -> None:
+ blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item]
+ assert _flatten_text_blocks(blocks) is None
+
+ def test_returns_none_when_text_field_missing(self) -> None:
+ blocks = [{"type": "text"}] # no ``text`` key
+ assert _flatten_text_blocks(blocks) is None
+
+ def test_returns_none_when_text_is_not_string(self) -> None:
+ blocks = [{"type": "text", "text": ["nested", "list"]}]
+ assert _flatten_text_blocks(blocks) is None
+
+ def test_drops_cache_control_from_inner_blocks(self) -> None:
+ # The whole point: existing cache_control on inner blocks is
+ # discarded so LiteLLM's ``cache_control_injection_points`` can
+ # re-attach exactly one breakpoint after flattening.
+ blocks = [
+ {"type": "text", "text": "first"},
+ {
+ "type": "text",
+ "text": "second",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ]
+ flattened = _flatten_text_blocks(blocks)
+ assert flattened == "first\n\nsecond"
+ assert "cache_control" not in flattened # type: ignore[operator]
+
+
+# ---------------------------------------------------------------------------
+# _flattened_request — decides when to override and when to no-op.
+# ---------------------------------------------------------------------------
+
+
+def _make_request(system_message: SystemMessage | None) -> Any:
+ """Build a minimal ModelRequest stub. We only need .system_message
+ and .override(system_message=...) — the middleware never touches
+ other fields.
+ """
+ request = MagicMock()
+ request.system_message = system_message
+
+ def override(**kwargs: Any) -> Any:
+ new_request = MagicMock()
+ new_request.system_message = kwargs.get(
+ "system_message", request.system_message
+ )
+ new_request.messages = kwargs.get("messages", getattr(request, "messages", []))
+ new_request.tools = kwargs.get("tools", getattr(request, "tools", []))
+ return new_request
+
+ request.override = override
+ return request
+
+
+class TestFlattenedRequest:
+ def test_collapses_multi_block_system_to_string(self) -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ ]
+ )
+ request = _make_request(sys)
+ flattened = _flattened_request(request)
+
+ assert flattened is not None
+ assert isinstance(flattened.system_message, SystemMessage)
+ assert flattened.system_message.content == (
+ "\n\n\n\n\n\n\n\n"
+ )
+
+ def test_no_op_for_string_content(self) -> None:
+ sys = SystemMessage(content="already a string")
+ request = _make_request(sys)
+ assert _flattened_request(request) is None
+
+ def test_no_op_for_single_block_list(self) -> None:
+ # One block already produces one breakpoint — no need to flatten.
+ sys = SystemMessage(content=[{"type": "text", "text": "single"}])
+ request = _make_request(sys)
+ assert _flattened_request(request) is None
+
+ def test_no_op_when_system_message_missing(self) -> None:
+ request = _make_request(None)
+ assert _flattened_request(request) is None
+
+ def test_no_op_when_list_contains_non_text_block(self) -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "look"},
+ {"type": "image_url", "image_url": {"url": "data:..."}},
+ ]
+ )
+ request = _make_request(sys)
+ assert _flattened_request(request) is None
+
+ def test_preserves_additional_kwargs_and_metadata(self) -> None:
+ # Defensive: nothing in the current chain sets these on a system
+ # message, but losing them silently when something does in the
+ # future would be a regression. ``name`` in particular is the only
+ # ``additional_kwargs`` field that ChatLiteLLM's
+ # ``_convert_message_to_dict`` propagates onto the wire.
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "a"},
+ {"type": "text", "text": "b"},
+ ],
+ additional_kwargs={"name": "surfsense_system", "x": 1},
+ response_metadata={"tokens": 42},
+ )
+ sys.id = "sys-msg-1"
+ request = _make_request(sys)
+
+ flattened = _flattened_request(request)
+ assert flattened is not None
+ assert flattened.system_message.content == "a\n\nb"
+ assert flattened.system_message.additional_kwargs == {
+ "name": "surfsense_system",
+ "x": 1,
+ }
+ assert flattened.system_message.response_metadata == {"tokens": 42}
+ assert flattened.system_message.id == "sys-msg-1"
+
+ def test_idempotent_when_run_twice(self) -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "a"},
+ {"type": "text", "text": "b"},
+ ]
+ )
+ request = _make_request(sys)
+ first = _flattened_request(request)
+ assert first is not None
+
+ # Second pass on the already-flattened request should be a no-op.
+ # We re-wrap in a request stub since the helper inspects
+ # ``request.system_message.content``.
+ second_request = _make_request(first.system_message)
+ assert _flattened_request(second_request) is None
+
+
+# ---------------------------------------------------------------------------
+# Middleware integration — verify the handler sees a flattened request.
+# ---------------------------------------------------------------------------
+
+
+class TestMiddlewareWrap:
+ @pytest.mark.asyncio
+ async def test_async_passes_flattened_request_to_handler(self) -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "alpha"},
+ {"type": "text", "text": "beta"},
+ ]
+ )
+ request = _make_request(sys)
+ captured: dict[str, Any] = {}
+
+ async def handler(req: Any) -> str:
+ captured["request"] = req
+ return "ok"
+
+ mw = FlattenSystemMessageMiddleware()
+ result = await mw.awrap_model_call(request, handler)
+
+ assert result == "ok"
+ assert isinstance(captured["request"].system_message, SystemMessage)
+ assert captured["request"].system_message.content == "alpha\n\nbeta"
+
+ @pytest.mark.asyncio
+ async def test_async_passes_through_when_already_string(self) -> None:
+ sys = SystemMessage(content="just a string")
+ request = _make_request(sys)
+ captured: dict[str, Any] = {}
+
+ async def handler(req: Any) -> str:
+ captured["request"] = req
+ return "ok"
+
+ mw = FlattenSystemMessageMiddleware()
+ await mw.awrap_model_call(request, handler)
+
+ # Same request object: no override happened.
+ assert captured["request"] is request
+
+ def test_sync_passes_flattened_request_to_handler(self) -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "alpha"},
+ {"type": "text", "text": "beta"},
+ ]
+ )
+ request = _make_request(sys)
+ captured: dict[str, Any] = {}
+
+ def handler(req: Any) -> str:
+ captured["request"] = req
+ return "ok"
+
+ mw = FlattenSystemMessageMiddleware()
+ result = mw.wrap_model_call(request, handler)
+
+ assert result == "ok"
+ assert captured["request"].system_message.content == "alpha\n\nbeta"
+
+ def test_sync_passes_through_when_no_system_message(self) -> None:
+ request = _make_request(None)
+ captured: dict[str, Any] = {}
+
+ def handler(req: Any) -> str:
+ captured["request"] = req
+ return "ok"
+
+ mw = FlattenSystemMessageMiddleware()
+ mw.wrap_model_call(request, handler)
+ assert captured["request"] is request
+
+
+# ---------------------------------------------------------------------------
+# Regression guard — pin the worst-case shape that triggered the
+# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
+# downstream cache_control_injection_points can only place 1 breakpoint
+# on the system message regardless of provider redistribution quirks.
+# ---------------------------------------------------------------------------
+
+
+def test_regression_five_block_system_collapses_to_one_block() -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ ]
+ )
+ request = _make_request(sys)
+ flattened = _flattened_request(request)
+
+ assert flattened is not None
+ assert isinstance(flattened.system_message.content, str)
+ # The exact join doesn't matter for the cache_control accounting —
+ # only that there is exactly ONE content block when LiteLLM's
+ # AnthropicCacheControlHook later targets ``role: system``.
+ assert " None:
+ # Sanity: the middleware MUST NOT touch user messages — only the
+ # system message. Multi-block user content is the path that carries
+ # image attachments and would lose its image_url block on
+ # accidental flatten.
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "a"},
+ {"type": "text", "text": "b"},
+ ]
+ )
+ user = HumanMessage(
+ content=[
+ {"type": "text", "text": "look at this"},
+ {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
+ ]
+ )
+ request = _make_request(sys)
+ request.messages = [user]
+
+ flattened = _flattened_request(request)
+ assert flattened is not None
+ # System flattened to string …
+ assert isinstance(flattened.system_message.content, str)
+ # … user message is untouched (the helper does not even look at it).
+ assert flattened.messages == [user]
+ assert isinstance(user.content, list)
+ assert len(user.content) == 2
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py
index 0bbdf37bf..d0ea73376 100644
--- a/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py
@@ -27,6 +27,7 @@ class TestDefaultAutoApprovedToolsList:
expected = {
"create_gmail_draft",
"update_gmail_draft",
+ "create_calendar_event",
"create_notion_page",
"create_confluence_page",
"create_google_drive_file",
@@ -41,13 +42,12 @@ class TestDefaultAutoApprovedToolsList:
assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset)
def test_send_tools_are_not_auto_approved(self) -> None:
- # External-broadcast tools must always prompt.
+ # External-broadcast / destructive tools must always prompt.
for tool_name in (
"send_gmail_email",
"send_discord_message",
"send_teams_message",
"delete_notion_page",
- "create_calendar_event",
"delete_calendar_event",
):
assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, (
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py
new file mode 100644
index 000000000..4cf53969d
--- /dev/null
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py
@@ -0,0 +1,370 @@
+r"""Tests for ``apply_litellm_prompt_caching`` in
+:mod:`app.agents.new_chat.prompt_caching`.
+
+The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
+never activated for our LiteLLM stack) with LiteLLM-native multi-provider
+prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
+``litellm.completion(...)``. The tests below pin its public contract:
+
+1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so
+ savings compound across multi-turn conversations on Anthropic-family
+ providers. ``index: 0`` is used (rather than ``role: system``) because
+ the deepagent stack accumulates multiple ``SystemMessage``\ s in
+ ``state["messages"]`` and ``role: system`` would tag every one of
+ them, blowing past Anthropic's 4-block ``cache_control`` cap.
+2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
+ single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
+ prompt-cache surface is available).
+3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only — no
+ OpenAI-only kwargs because the router fans out across providers.
+4. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
+5. Defensive: LLMs without a writable ``model_kwargs`` are silently
+ skipped rather than raising.
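+
+For a single-model OpenAI-family config with a thread id, the helper is
+expected to leave behind (values as pinned by the tests below)::
+
+ llm.model_kwargs == {
+ "cache_control_injection_points": [
+ {"location": "message", "index": 0},
+ {"location": "message", "index": -1},
+ ],
+ "prompt_cache_key": "surfsense-thread-42",
+ "prompt_cache_retention": "24h",
+ }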
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+import pytest
+
+from app.agents.new_chat.llm_config import AgentConfig
+from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Test doubles
+# ---------------------------------------------------------------------------
+
+
+class _FakeLLM:
+ """Stand-in for ``ChatLiteLLM``/``SanitizedChatLiteLLM``.
+
+ The helper only inspects ``getattr(llm, "model_kwargs", None)``,
+ ``getattr(llm, "model", None)``, and ``type(llm).__name__``. A simple
+ object suffices — we don't need to spin up real LangChain/LiteLLM
+ machinery for unit tests of the helper's logic.
+ """
+
+ def __init__(
+ self,
+ model: str = "openai/gpt-4o",
+ model_kwargs: dict[str, Any] | None = None,
+ ) -> None:
+ self.model = model
+ self.model_kwargs: dict[str, Any] = dict(model_kwargs) if model_kwargs else {}
+
+
+class ChatLiteLLMRouter:
+ """Class-name-only impostor of the real router.
+
+ The helper's router gate is ``type(llm).__name__ == "ChatLiteLLMRouter"``
+ (a deliberate stringly-typed check to avoid an import cycle with
+ ``app.services.llm_router_service``). Reusing the same class name here
+ triggers the same code path without instantiating a real ``Router``.
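+
+ The gate under test, as described (sketch)::
+
+ if type(llm).__name__ == "ChatLiteLLMRouter":
+ ... # universal injection points only; no OpenAI extras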
+ """
+
+ def __init__(self) -> None:
+ self.model = "auto"
+ self.model_kwargs: dict[str, Any] = {}
+
+
+def _make_cfg(**overrides: Any) -> AgentConfig:
+ """Build an ``AgentConfig`` with sensible defaults for the helper test."""
+ defaults: dict[str, Any] = {
+ "provider": "OPENAI",
+ "model_name": "gpt-4o",
+ "api_key": "k",
+ }
+ return AgentConfig(**{**defaults, **overrides})
+
+
+# ---------------------------------------------------------------------------
+# (a) Universal injection points
+# ---------------------------------------------------------------------------
+
+
+def test_sets_both_cache_control_injection_points_with_no_config() -> None:
+ """Bare call (no agent_config, no thread_id) still sets the two
+ universal breakpoints — these cost nothing on providers that don't
+ consume them and unlock caching on every supported provider."""
+ llm = _FakeLLM()
+
+ apply_litellm_prompt_caching(llm)
+
+ points = llm.model_kwargs["cache_control_injection_points"]
+ assert {"location": "message", "index": 0} in points
+ assert {"location": "message", "index": -1} in points
+ assert len(points) == 2
+
+
+def test_does_not_inject_role_system_breakpoint() -> None:
+ """Regression: deliberately AVOID ``role: system`` so we don't tag
+ every SystemMessage the deepagent ``before_agent`` injectors push
+ into ``state["messages"]`` (priority, tree, memory, file-intent,
+ anonymous-doc). Tagging all of them overflows Anthropic's 4-block
+ ``cache_control`` cap and surfaces as
+ ``OpenrouterException: A maximum of 4 blocks with cache_control may
+ be provided. Found N`` 400s.
+ """
+ llm = _FakeLLM()
+ apply_litellm_prompt_caching(llm)
+ points = llm.model_kwargs["cache_control_injection_points"]
+ assert all(p.get("role") != "system" for p in points), (
+ f"Expected no role=system breakpoint, got: {points}"
+ )
+
+
+def test_injection_points_set_for_anthropic_config() -> None:
+ """Anthropic-family configs need the marker — verify it lands."""
+ cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet")
+ llm = _FakeLLM(model="anthropic/claude-3-5-sonnet")
+
+ apply_litellm_prompt_caching(llm, agent_config=cfg)
+
+ assert "cache_control_injection_points" in llm.model_kwargs
+
+
+# ---------------------------------------------------------------------------
+# (b) Idempotency / user override wins
+# ---------------------------------------------------------------------------
+
+
+def test_does_not_overwrite_user_supplied_cache_control_injection_points() -> None:
+ """Users who set their own injection points (e.g. with ``ttl: "1h"``
+ via ``litellm_params``) keep them — the helper merges, never
+ clobbers."""
+ user_points = [
+ {"location": "message", "role": "system", "ttl": "1h"},
+ ]
+ llm = _FakeLLM(
+ model_kwargs={"cache_control_injection_points": user_points},
+ )
+
+ apply_litellm_prompt_caching(llm)
+
+ assert llm.model_kwargs["cache_control_injection_points"] is user_points
+
+
+def test_idempotent_when_called_multiple_times() -> None:
+ """Build-time + thread-time double-call must be a no-op the second time."""
+ cfg = _make_cfg(provider="OPENAI")
+ llm = _FakeLLM()
+
+ apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1)
+ snapshot = {
+ "cache_control_injection_points": list(
+ llm.model_kwargs["cache_control_injection_points"]
+ ),
+ "prompt_cache_key": llm.model_kwargs["prompt_cache_key"],
+ "prompt_cache_retention": llm.model_kwargs["prompt_cache_retention"],
+ }
+ apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1)
+
+ assert (
+ llm.model_kwargs["cache_control_injection_points"]
+ == snapshot["cache_control_injection_points"]
+ )
+ assert llm.model_kwargs["prompt_cache_key"] == snapshot["prompt_cache_key"]
+ assert (
+ llm.model_kwargs["prompt_cache_retention"] == snapshot["prompt_cache_retention"]
+ )
+
+
+def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None:
+ """A pre-set ``prompt_cache_key`` (e.g. tenant-aware override via
+ ``litellm_params``) wins over our default per-thread key."""
+ cfg = _make_cfg(provider="OPENAI")
+ llm = _FakeLLM(model_kwargs={"prompt_cache_key": "tenant-abc"})
+
+ apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
+
+ assert llm.model_kwargs["prompt_cache_key"] == "tenant-abc"
+
+
+# ---------------------------------------------------------------------------
+# (c) OpenAI-family extras (OPENAI / DEEPSEEK / XAI)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
+def test_sets_openai_family_extras(provider: str) -> None:
+ """OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate
+ via routing affinity) and ``prompt_cache_retention="24h"`` (extends
+ cache TTL beyond the default 5-10 min)."""
+ cfg = _make_cfg(provider=provider)
+ llm = _FakeLLM()
+
+ apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
+
+ assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-42"
+ assert llm.model_kwargs["prompt_cache_retention"] == "24h"
+
+
+def test_skips_prompt_cache_key_when_no_thread_id() -> None:
+ """Without a thread id we can't construct a per-thread key. Retention
+ is still useful, so we set it (it's free)."""
+ cfg = _make_cfg(provider="OPENAI")
+ llm = _FakeLLM()
+
+ apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=None)
+
+ assert "prompt_cache_key" not in llm.model_kwargs
+ assert llm.model_kwargs["prompt_cache_retention"] == "24h"
+
+
+@pytest.mark.parametrize(
+ "provider",
+ ["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
+)
+def test_no_openai_extras_for_other_providers(provider: str) -> None:
+ """Non-OpenAI-family providers don't expose ``prompt_cache_key`` —
+ skip it. ``cache_control_injection_points`` is still set (universal)."""
+ cfg = _make_cfg(provider=provider)
+ llm = _FakeLLM()
+
+ apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
+
+ assert "prompt_cache_key" not in llm.model_kwargs
+ assert "prompt_cache_retention" not in llm.model_kwargs
+ assert "cache_control_injection_points" in llm.model_kwargs
+
+
+def test_no_openai_extras_in_auto_mode() -> None:
+ """Auto-mode fans out across mixed providers — we can't statically
+ target OpenAI-only kwargs."""
+ cfg = AgentConfig.from_auto_mode()
+ llm = _FakeLLM()
+
+ apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
+
+ assert "prompt_cache_key" not in llm.model_kwargs
+ assert "prompt_cache_retention" not in llm.model_kwargs
+ assert "cache_control_injection_points" in llm.model_kwargs
+
+
+def test_no_openai_extras_for_custom_provider() -> None:
+ """Custom providers route through arbitrary user-supplied prefixes —
+ we don't try to infer OpenAI-family compatibility."""
+ cfg = _make_cfg(provider="OPENAI", custom_provider="my_proxy")
+ llm = _FakeLLM()
+
+ apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
+
+ assert "prompt_cache_key" not in llm.model_kwargs
+ assert "prompt_cache_retention" not in llm.model_kwargs
+
+
+# ---------------------------------------------------------------------------
+# (d) ChatLiteLLMRouter — universal injection points only
+# ---------------------------------------------------------------------------
+
+
+def test_router_llm_gets_only_universal_injection_points() -> None:
+ """Even with an OpenAI-flavoured config, a ``ChatLiteLLMRouter`` must
+ receive only the universal injection points — its requests dispatch
+ across provider deployments and OpenAI-only kwargs would be wasted
+ (or stripped by ``drop_params``) on non-OpenAI legs."""
+ router = ChatLiteLLMRouter()
+ cfg = _make_cfg(provider="OPENAI")
+
+ apply_litellm_prompt_caching(router, agent_config=cfg, thread_id=42)
+
+ assert "cache_control_injection_points" in router.model_kwargs
+ assert "prompt_cache_key" not in router.model_kwargs
+ assert "prompt_cache_retention" not in router.model_kwargs
+
+
+# ---------------------------------------------------------------------------
+# (e) Defensive paths
+# ---------------------------------------------------------------------------
+
+
+def test_handles_llm_with_no_writable_model_kwargs() -> None:
+ """Some LLM implementations (e.g. fakes / minimal subclasses) don't
+ expose a writable ``model_kwargs``. The helper must skip silently —
+ raising would crash the entire LLM build path on a non-critical
+ optimisation."""
+
+ class _ImmutableLLM:
+ # ``__slots__`` blocks attribute creation, so ``setattr`` raises.
+ __slots__ = ("model",)
+
+ def __init__(self) -> None:
+ self.model = "openai/gpt-4o"
+
+ llm = _ImmutableLLM()
+
+ apply_litellm_prompt_caching(llm)
+
+
+def test_initialises_missing_model_kwargs_dict() -> None:
+ """When ``model_kwargs`` is present-but-None (Pydantic v2 default
+ pattern when no factory is set), the helper initialises it to an
+ empty dict before mutating."""
+
+ class _LazyLLM:
+ def __init__(self) -> None:
+ self.model = "openai/gpt-4o"
+ self.model_kwargs: dict[str, Any] | None = None
+
+ llm = _LazyLLM()
+
+ apply_litellm_prompt_caching(llm)
+
+ assert isinstance(llm.model_kwargs, dict)
+ assert "cache_control_injection_points" in llm.model_kwargs
+
+
+def test_falls_back_to_llm_model_prefix_when_no_agent_config() -> None:
+ """Direct caller path (e.g. ``create_chat_litellm_from_config`` for
+ YAML configs without a structured ``AgentConfig``): without
+ ``agent_config`` the helper sets only the universal injection points
+ — no OpenAI-family extras even if the prefix says ``openai/``.
+ Conservative: we'd rather miss the speedup than silently misroute."""
+ llm = _FakeLLM(model="openai/gpt-4o")
+
+ apply_litellm_prompt_caching(llm, agent_config=None, thread_id=99)
+
+ assert "cache_control_injection_points" in llm.model_kwargs
+ assert "prompt_cache_key" not in llm.model_kwargs
+ assert "prompt_cache_retention" not in llm.model_kwargs
+
+
+# ---------------------------------------------------------------------------
+# (f) drop_params safety net (regression guard for #19346)
+# ---------------------------------------------------------------------------
+
+
+def test_litellm_drop_params_is_globally_enabled() -> None:
+ """``litellm.drop_params=True`` is set globally in
+ :mod:`app.services.llm_service` so any ``prompt_cache_key`` /
+ ``prompt_cache_retention`` we set on an OpenAI-family config is
+ auto-stripped if the request later routes to a non-supporting
+ provider (e.g. via auto-mode router fallback). This test pins that
+ invariant — losing it would mean Bedrock/Vertex 400s on ``prompt_cache_key``.
+ """
+ import litellm
+
+ import app.services.llm_service # noqa: F401 (side-effect: sets globals)
+
+ assert litellm.drop_params is True
+
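+# With ``drop_params`` on, LiteLLM silently strips unsupported params
+# instead of raising (illustrative call, not from this test suite):
+#
+#     litellm.completion(
+#         model="bedrock/anthropic.claude-3-5-sonnet",
+#         messages=[...],
+#         prompt_cache_key="surfsense-thread-42",  # dropped for Bedrock
+#     )
+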
+
+# ---------------------------------------------------------------------------
+# Regression note: LiteLLM #15696 (multi-content-block last message)
+# ---------------------------------------------------------------------------
+#
+# Before LiteLLM 1.81 a list-form last message ``[block_a, block_b]``
+# would get ``cache_control`` applied to *every* content block instead
+# of only the last one — wasting cache breakpoints and triggering 400s
+# on Anthropic when it exceeded the 4-breakpoint limit. Fixed in
+# https://github.com/BerriAI/litellm/pull/15699.
+#
+# We pin ``litellm>=1.83.7`` in ``pyproject.toml`` (well past the fix).
+# An end-to-end behavioural test would need to run ``litellm.completion``
+# through the Anthropic transformer, which is integration territory and
+# better covered by LiteLLM's own test suite. The unit guard here is the
+# version pin plus the build-time ``model_kwargs`` shape we verify above.
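+#
+# For reference, the buggy vs. fixed placement (illustrative message
+# shape, not LiteLLM internals):
+#
+#     last_message = {"role": "user", "content": [block_a, block_b]}
+#     # pre-fix:  cache_control stamped on block_a AND block_b
+#     # post-fix: cache_control stamped on block_b only (the last block)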
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py b/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py
new file mode 100644
index 000000000..ffe3dbaa4
--- /dev/null
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py
@@ -0,0 +1,117 @@
+"""Tests for ``_resolve_prompt_model_name`` in :mod:`app.agents.new_chat.chat_deepagent`.
+
+The helper picks the model id fed to ``detect_provider_variant`` so the
+right provider-specific prompt block lands in the system prompt. The tests
+below pin its preference order:
+
+1. ``agent_config.litellm_params["base_model"]`` (Azure-correct).
+2. ``agent_config.model_name``.
+3. ``getattr(llm, "model", None)``.
+
+Without (1) an Azure deployment named e.g. ``"prod-chat-001"`` would
+silently miss every provider regex.
+"""
+
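+# Rough shape of the preference order (illustrative sketch inferred
+# from the tests below; the real helper lives in chat_deepagent.py):
+#
+#     params = getattr(agent_config, "litellm_params", None) or {}
+#     base = params.get("base_model")
+#     if isinstance(base, str) and base.strip():
+#         return base
+#     if agent_config is not None and agent_config.model_name:
+#         return agent_config.model_name
+#     return getattr(llm, "model", None)
+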
+from __future__ import annotations
+
+import pytest
+
+from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name
+from app.agents.new_chat.llm_config import AgentConfig
+
+pytestmark = pytest.mark.unit
+
+
+def _make_cfg(**overrides) -> AgentConfig:
+ """Build an ``AgentConfig`` with sensible defaults for the helper test."""
+ defaults = {
+ "provider": "OPENAI",
+ "model_name": "x",
+ "api_key": "k",
+ }
+ return AgentConfig(**{**defaults, **overrides})
+
+
+class _FakeLLM:
+ """Stand-in for a ``ChatLiteLLM`` / ``ChatLiteLLMRouter`` instance.
+
+ The resolver only reads the ``.model`` attribute via ``getattr``,
+ matching the established idiom in ``knowledge_search.py`` /
+ ``stream_new_chat.py`` / ``document_summarizer.py``.
+ """
+
+ def __init__(self, model: str | None) -> None:
+ self.model = model
+
+
+def test_prefers_litellm_params_base_model_over_deployment_name() -> None:
+ """Azure deployment slug must NOT shadow the underlying model family.
+
+ This is the failure mode the helper exists to prevent: a deployment
+ named ``"azure/prod-chat-001"`` would not match any provider regex
+ on its own, but the family ``"gpt-4o"`` lives in
+ ``litellm_params["base_model"]`` and routes to ``openai_classic``.
+ """
+ cfg = _make_cfg(
+ model_name="azure/prod-chat-001",
+ litellm_params={"base_model": "gpt-4o"},
+ )
+ assert _resolve_prompt_model_name(cfg, _FakeLLM("azure/prod-chat-001")) == "gpt-4o"
+
+
+def test_falls_back_to_model_name_when_litellm_params_is_none() -> None:
+ cfg = _make_cfg(
+ model_name="anthropic/claude-3-5-sonnet",
+ litellm_params=None,
+ )
+ got = _resolve_prompt_model_name(cfg, _FakeLLM("anthropic/claude-3-5-sonnet"))
+ assert got == "anthropic/claude-3-5-sonnet"
+
+
+def test_handles_litellm_params_without_base_model_key() -> None:
+ cfg = _make_cfg(
+ model_name="openai/gpt-4o",
+ litellm_params={"temperature": 0.5},
+ )
+ assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
+
+
+def test_ignores_blank_base_model() -> None:
+ """Whitespace-only ``base_model`` must not shadow ``model_name``."""
+ cfg = _make_cfg(
+ model_name="openai/gpt-4o",
+ litellm_params={"base_model": " "},
+ )
+ assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
+
+
+def test_ignores_non_string_base_model() -> None:
+ """Defensive: a non-string ``base_model`` should not crash the resolver."""
+ cfg = _make_cfg(
+ model_name="openai/gpt-4o",
+ litellm_params={"base_model": 42},
+ )
+ assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
+
+
+def test_falls_back_to_llm_model_when_no_agent_config() -> None:
+ """No ``agent_config`` -> use ``llm.model`` directly. Defensive path
+ for direct callers; production callers always supply a config."""
+ assert (
+ _resolve_prompt_model_name(None, _FakeLLM("openai/gpt-4o-mini"))
+ == "openai/gpt-4o-mini"
+ )
+
+
+def test_returns_none_when_nothing_available() -> None:
+ """``compose_system_prompt`` treats ``None`` as the ``"default"``
+ variant and emits no provider block."""
+ assert _resolve_prompt_model_name(None, _FakeLLM(None)) is None
+
+
+def test_auto_mode_resolves_to_auto_string() -> None:
+ """Auto mode -> ``"auto"``. ``detect_provider_variant("auto")``
+ returns ``"default"``, which is correct: the child model isn't
+ known until the LiteLLM Router dispatches."""
+ cfg = AgentConfig.from_auto_mode()
+ assert _resolve_prompt_model_name(cfg, _FakeLLM("auto")) == "auto"
diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py
index 2ca470680..2933a0504 100644
--- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py
+++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py
@@ -475,3 +475,190 @@ class TestKBSearchPlanSchema:
)
)
assert plan.is_recency_query is False
+
+
+# ── mentioned_document_ids cross-turn drain ────────────────────────────
+
+
+class TestKnowledgePriorityMentionDrain:
+ """Regression tests for the cross-turn ``mentioned_document_ids`` drain.
+
+ The compiled-agent cache reuses a single :class:`KnowledgeBaseSearchMiddleware`
+ instance across turns of the same thread. ``mentioned_document_ids``
+ can therefore enter the middleware via two paths:
+
+ 1. The constructor closure (``__init__(mentioned_document_ids=...)``) —
+ seeded by the cache-miss build on turn 1.
+ 2. ``runtime.context.mentioned_document_ids`` — supplied freshly per
+ turn by the streaming task.
+
+ Without the drain fix, an empty ``runtime.context.mentioned_document_ids``
+ on turn 2 would fall through to the closure (because ``[]`` is falsy in
+ Python) and replay turn 1's mentions. This class pins down the
+ correct behaviour: the runtime path is authoritative even when empty,
+ and the closure is drained the first time the runtime path fires so
+ no later turn can ever resurrect stale state.
+ """
+
+ @staticmethod
+ def _make_runtime(mention_ids: list[int]):
+ """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
+ from types import SimpleNamespace
+
+ return SimpleNamespace(
+ context=SimpleNamespace(mentioned_document_ids=mention_ids),
+ )
+
+ @staticmethod
+ def _planner_llm() -> "FakeLLM":
+ # Planner returns a stable, non-recency plan so we always land in
+ # the hybrid-search branch (where ``fetch_mentioned_documents`` is
+ # invoked alongside the main search).
+ return FakeLLM(
+ json.dumps(
+ {
+ "optimized_query": "follow up question",
+ "start_date": None,
+ "end_date": None,
+ "is_recency_query": False,
+ }
+ )
+ )
+
+ async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
+ """Turn 1 with mentions in BOTH closure and runtime context: the
+ runtime path wins AND the closure is drained so a future turn
+ cannot replay it.
+ """
+ fetched_ids: list[list[int]] = []
+
+ async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
+ fetched_ids.append(list(document_ids))
+ return []
+
+ async def fake_search_knowledge_base(**_kwargs):
+ return []
+
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
+ fake_fetch_mentioned_documents,
+ )
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
+ fake_search_knowledge_base,
+ )
+
+ middleware = KnowledgeBaseSearchMiddleware(
+ llm=self._planner_llm(),
+ search_space_id=42,
+ mentioned_document_ids=[1, 2, 3],
+ )
+
+ await middleware.abefore_agent(
+ {"messages": [HumanMessage(content="what is in those docs?")]},
+ runtime=self._make_runtime([1, 2, 3]),
+ )
+
+ assert fetched_ids == [[1, 2, 3]], (
+ "runtime.context mentions must be the source of truth on turn 1"
+ )
+ assert middleware.mentioned_document_ids == [], (
+ "closure must be drained the first time the runtime path fires "
+ "so no later turn can replay stale mentions"
+ )
+
+ async def test_empty_runtime_context_does_not_replay_closure_mentions(
+ self, monkeypatch
+ ):
+ """Regression: turn 2 with NO mentions must not surface turn 1's
+ mentions from the constructor closure.
+
+ Before the fix, ``if ctx_mentions:`` treated an empty list as
+ absent and fell through to ``elif self.mentioned_document_ids:``,
+ replaying turn 1's mentions. This test pins down the corrected
+ behaviour.
+ """
+ fetched_ids: list[list[int]] = []
+
+ async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
+ fetched_ids.append(list(document_ids))
+ return []
+
+ async def fake_search_knowledge_base(**_kwargs):
+ return []
+
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
+ fake_fetch_mentioned_documents,
+ )
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
+ fake_search_knowledge_base,
+ )
+
+ # Simulate a cached middleware instance whose closure was seeded
+ # by a previous turn's cache-miss build (mentions=[1,2,3]).
+ middleware = KnowledgeBaseSearchMiddleware(
+ llm=self._planner_llm(),
+ search_space_id=42,
+ mentioned_document_ids=[1, 2, 3],
+ )
+
+ # Turn 2: streaming task supplies an EMPTY mention list (no
+ # mentions on this follow-up turn).
+ await middleware.abefore_agent(
+ {"messages": [HumanMessage(content="what about the next steps?")]},
+ runtime=self._make_runtime([]),
+ )
+
+ assert fetched_ids == [], (
+ "fetch_mentioned_documents must NOT be called when the runtime "
+ "context says there are no mentions for this turn"
+ )
+
+ async def test_legacy_path_fires_only_when_runtime_context_absent(
+ self, monkeypatch
+ ):
+ """Backward-compat: if a caller doesn't supply runtime.context (old
+ non-streaming code path), the closure-injected mentions are still
+ honoured exactly once and then drained.
+ """
+ fetched_ids: list[list[int]] = []
+
+ async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
+ fetched_ids.append(list(document_ids))
+ return []
+
+ async def fake_search_knowledge_base(**_kwargs):
+ return []
+
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
+ fake_fetch_mentioned_documents,
+ )
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
+ fake_search_knowledge_base,
+ )
+
+ middleware = KnowledgeBaseSearchMiddleware(
+ llm=self._planner_llm(),
+ search_space_id=42,
+ mentioned_document_ids=[7, 8],
+ )
+
+ # First call: no runtime → legacy path uses the closure.
+ await middleware.abefore_agent(
+ {"messages": [HumanMessage(content="initial question")]},
+ runtime=None,
+ )
+ # Second call: still no runtime — closure already drained, so no replay.
+ await middleware.abefore_agent(
+ {"messages": [HumanMessage(content="follow up")]},
+ runtime=None,
+ )
+
+ assert fetched_ids == [[7, 8]], (
+ "legacy path must honour the closure exactly once and then drain it"
+ )
+ assert middleware.mentioned_document_ids == []
diff --git a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py
new file mode 100644
index 000000000..c9f18d77d
--- /dev/null
+++ b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py
@@ -0,0 +1,110 @@
+"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
+endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).
+
+There is no DB column for ``supports_image_input`` on
+``NewLLMConfig`` — the value is resolved at the API boundary by
+``derive_supports_image_input`` so the new-chat selector / streaming
+task can read the same field shape regardless of source (BYOK vs YAML
+vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
+user out of their own model choice.
+"""
+
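+# Boundary derivation under test, as a sketch (illustrative; the real
+# helper is ``derive_supports_image_input`` and its exact signature may
+# differ):
+#
+#     base = (row.litellm_params or {}).get("base_model")
+#     supports = derive_supports_image_input(base or row.model_name)
+#     # default-allow: True unless LiteLLM's model map explicitly says
+#     # the model is text-only.
+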
+from __future__ import annotations
+
+from datetime import UTC, datetime
+from types import SimpleNamespace
+from uuid import uuid4
+
+import pytest
+
+from app.db import LiteLLMProvider
+from app.routes import new_llm_config_routes
+
+pytestmark = pytest.mark.unit
+
+
+def _byok_row(
+ *,
+ id_: int,
+ model_name: str,
+ base_model: str | None = None,
+ provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
+ custom_provider: str | None = None,
+) -> object:
+ """Mimic the SQLAlchemy row's attribute surface; ``model_validate``
+ walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough.
+
+ ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's
+ enum validator accepts it — same as the ORM row would carry."""
+ return SimpleNamespace(
+ id=id_,
+ name=f"BYOK-{id_}",
+ description=None,
+ provider=provider,
+ custom_provider=custom_provider,
+ model_name=model_name,
+ api_key="sk-byok",
+ api_base=None,
+ litellm_params={"base_model": base_model} if base_model else None,
+ system_instructions="",
+ use_default_system_instructions=True,
+ citations_enabled=True,
+ created_at=datetime.now(tz=UTC),
+ search_space_id=42,
+ user_id=uuid4(),
+ )
+
+
+def test_serialize_byok_known_vision_model_resolves_true():
+ """The catalog resolver consults LiteLLM's map for ``gpt-4o`` ->
+ True. The serialized row carries that value through to the
+ ``NewLLMConfigRead`` schema."""
+ row = _byok_row(id_=1, model_name="gpt-4o")
+ serialized = new_llm_config_routes._serialize_byok_config(row)
+
+ assert serialized.supports_image_input is True
+ assert serialized.id == 1
+ assert serialized.model_name == "gpt-4o"
+
+
+def test_serialize_byok_unknown_model_default_allows():
+ """Unknown / unmapped: default-allow. The streaming-task safety net
+ is the actual block, and it requires LiteLLM to *explicitly* say
+ text-only — so a brand-new BYOK model should not be pre-judged."""
+ row = _byok_row(
+ id_=2,
+ model_name="brand-new-model-x9-unmapped",
+ provider=LiteLLMProvider.CUSTOM,
+ custom_provider="brand_new_proxy",
+ )
+ serialized = new_llm_config_routes._serialize_byok_config(row)
+
+ assert serialized.supports_image_input is True
+
+
+def test_serialize_byok_uses_base_model_when_present():
+ """Azure-style: ``model_name`` is the deployment id, ``base_model``
+ inside ``litellm_params`` is the canonical sku LiteLLM knows. The
+ helper must consult ``base_model`` first or unrecognised deployment
+ ids would shadow the real capability."""
+ row = _byok_row(
+ id_=3,
+ model_name="my-azure-deployment-id-no-litellm-knows-this",
+ base_model="gpt-4o",
+ provider=LiteLLMProvider.AZURE_OPENAI,
+ )
+ serialized = new_llm_config_routes._serialize_byok_config(row)
+
+ assert serialized.supports_image_input is True
+
+
+def test_serialize_byok_returns_pydantic_read_model():
+ """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so
+ the schema additions are guaranteed to be present in the API
+ surface. This guards against a future regression where someone
+ deletes the augmentation step and falls back to ORM passthrough."""
+ from app.schemas import NewLLMConfigRead
+
+ row = _byok_row(id_=4, model_name="gpt-4o")
+ serialized = new_llm_config_routes._serialize_byok_config(row)
+ assert isinstance(serialized, NewLLMConfigRead)
diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py
new file mode 100644
index 000000000..2b6c76485
--- /dev/null
+++ b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py
@@ -0,0 +1,184 @@
+"""Unit tests for ``is_premium`` derivation on the global image-gen and
+vision-LLM list endpoints.
+
+Chat globals (``GET /global-llm-configs``) already emit
+``is_premium = (billing_tier == "premium")``. Image and vision did not,
+which made the new-chat ``model-selector`` render the Free/Premium badge
+on the Chat tab but skip it on the Image and Vision tabs (the selector
+keys its badge logic off ``is_premium``). These tests pin parity:
+
+* YAML free entry → ``is_premium=False``
+* YAML premium entry → ``is_premium=True``
+* OpenRouter dynamic premium entry → ``is_premium=True``
+* Auto stub (always emitted when at least one config is present)
+ → ``is_premium=False``
+"""
+
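+# The derivation under test is a one-liner server-side (illustrative):
+#
+#     cfg["is_premium"] = cfg.get("billing_tier") == "premium"
+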
+from __future__ import annotations
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+_IMAGE_FIXTURE: list[dict] = [
+ {
+ "id": -1,
+ "name": "DALL-E 3",
+ "provider": "OPENAI",
+ "model_name": "dall-e-3",
+ "api_key": "sk-test",
+ "billing_tier": "free",
+ },
+ {
+ "id": -2,
+ "name": "GPT-Image 1 (premium)",
+ "provider": "OPENAI",
+ "model_name": "gpt-image-1",
+ "api_key": "sk-test",
+ "billing_tier": "premium",
+ },
+ {
+ "id": -20_001,
+ "name": "google/gemini-2.5-flash-image (OpenRouter)",
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-2.5-flash-image",
+ "api_key": "sk-or-test",
+ "api_base": "https://openrouter.ai/api/v1",
+ "billing_tier": "premium",
+ },
+]
+
+
+_VISION_FIXTURE: list[dict] = [
+ {
+ "id": -1,
+ "name": "GPT-4o Vision",
+ "provider": "OPENAI",
+ "model_name": "gpt-4o",
+ "api_key": "sk-test",
+ "billing_tier": "free",
+ },
+ {
+ "id": -2,
+ "name": "Claude 3.5 Sonnet (premium)",
+ "provider": "ANTHROPIC",
+ "model_name": "claude-3-5-sonnet",
+ "api_key": "sk-ant-test",
+ "billing_tier": "premium",
+ },
+ {
+ "id": -30_001,
+ "name": "openai/gpt-4o (OpenRouter)",
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-4o",
+ "api_key": "sk-or-test",
+ "api_base": "https://openrouter.ai/api/v1",
+ "billing_tier": "premium",
+ },
+]
+
+
+# =============================================================================
+# Image generation
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
+ """Each emitted config must carry ``is_premium`` derived server-side
+ from ``billing_tier``. The Auto stub is always free.
+ """
+ from app.config import config
+ from app.routes import image_generation_routes
+
+ monkeypatch.setattr(
+ config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
+ )
+
+ payload = await image_generation_routes.get_global_image_gen_configs(user=None)
+
+ by_id = {c["id"]: c for c in payload}
+
+ # Auto stub is always emitted when at least one global config exists,
+ # and it must always declare itself free (Auto-mode billing-tier
+ # surfacing is a separate follow-up).
+ assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
+ assert by_id[0]["is_premium"] is False
+ assert by_id[0]["billing_tier"] == "free"
+
+ # YAML free entry — ``is_premium=False``
+ assert by_id[-1]["is_premium"] is False
+ assert by_id[-1]["billing_tier"] == "free"
+
+ # YAML premium entry — ``is_premium=True``
+ assert by_id[-2]["is_premium"] is True
+ assert by_id[-2]["billing_tier"] == "premium"
+
+ # OpenRouter dynamic premium entry — same field, same derivation
+ assert by_id[-20_001]["is_premium"] is True
+ assert by_id[-20_001]["billing_tier"] == "premium"
+
+ # Every emitted dict (including Auto) must have the field — never missing.
+ for cfg in payload:
+ assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
+ assert isinstance(cfg["is_premium"], bool)
+
+
+@pytest.mark.asyncio
+async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
+ """When there are no global configs at all, the endpoint emits an
+ empty list (no Auto stub) — Auto mode would have nothing to route to.
+ """
+ from app.config import config
+ from app.routes import image_generation_routes
+
+ monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)
+ payload = await image_generation_routes.get_global_image_gen_configs(user=None)
+ assert payload == []
+
+
+# =============================================================================
+# Vision LLM
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
+ from app.config import config
+ from app.routes import vision_llm_routes
+
+ monkeypatch.setattr(
+ config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
+ )
+
+ payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
+
+ by_id = {c["id"]: c for c in payload}
+
+ assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
+ assert by_id[0]["is_premium"] is False
+ assert by_id[0]["billing_tier"] == "free"
+
+ assert by_id[-1]["is_premium"] is False
+ assert by_id[-1]["billing_tier"] == "free"
+
+ assert by_id[-2]["is_premium"] is True
+ assert by_id[-2]["billing_tier"] == "premium"
+
+ assert by_id[-30_001]["is_premium"] is True
+ assert by_id[-30_001]["billing_tier"] == "premium"
+
+ for cfg in payload:
+ assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
+ assert isinstance(cfg["is_premium"], bool)
+
+
+@pytest.mark.asyncio
+async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
+ from app.config import config
+ from app.routes import vision_llm_routes
+
+ monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)
+ payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
+ assert payload == []
diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py
new file mode 100644
index 000000000..b47d9134b
--- /dev/null
+++ b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py
@@ -0,0 +1,106 @@
+"""Unit tests for ``supports_image_input`` derivation on the chat global
+config endpoint (``GET /global-new-llm-configs``).
+
+Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):
+
+1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
+ loader for operator overrides, or by the OpenRouter integration from
+ ``architecture.input_modalities``) — wins.
+2. ``derive_supports_image_input`` helper — default-allow on unknown
+ models, only False when LiteLLM / OR modalities are definitive.
+
+The flag is purely informational at the API boundary. The streaming
+task safety net (``is_known_text_only_chat_model``) is the actual block,
+and it requires LiteLLM to *explicitly* mark the model as text-only.
+"""
+
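+# Resolution order as a sketch (illustrative; the real logic lives in
+# ``new_llm_config_routes.get_global_new_llm_configs``):
+#
+#     explicit = cfg.get("supports_image_input")
+#     cfg["supports_image_input"] = (
+#         explicit
+#         if isinstance(explicit, bool)
+#         else derive_supports_image_input(...)  # default-allow on unknown
+#     )
+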
+from __future__ import annotations
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+_FIXTURE: list[dict] = [
+ {
+ "id": -1,
+ "name": "GPT-4o (explicit true)",
+ "description": "vision-capable, explicit YAML override",
+ "provider": "OPENAI",
+ "model_name": "gpt-4o",
+ "api_key": "sk-test",
+ "billing_tier": "free",
+ "supports_image_input": True,
+ },
+ {
+ "id": -2,
+ "name": "DeepSeek V3 (explicit false)",
+ "description": "OpenRouter dynamic — modality-derived false",
+ "provider": "OPENROUTER",
+ "model_name": "deepseek/deepseek-v3.2-exp",
+ "api_key": "sk-or-test",
+ "api_base": "https://openrouter.ai/api/v1",
+ "billing_tier": "free",
+ "supports_image_input": False,
+ },
+ {
+ "id": -10_010,
+ "name": "Unannotated GPT-4o",
+ "description": "no flag set — resolver should derive True via LiteLLM",
+ "provider": "OPENAI",
+ "model_name": "gpt-4o",
+ "api_key": "sk-test",
+ "billing_tier": "free",
+ # supports_image_input intentionally absent
+ },
+ {
+ "id": -10_011,
+ "name": "Unannotated unknown model",
+ "description": "unmapped — default-allow True",
+ "provider": "CUSTOM",
+ "custom_provider": "brand_new_proxy",
+ "model_name": "brand-new-model-x9",
+ "api_key": "sk-test",
+ "billing_tier": "free",
+ },
+]
+
+
+@pytest.mark.asyncio
+async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
+ """Each emitted chat config carries ``supports_image_input`` as a
+ bool. Explicit values win; unannotated entries are resolved via the
+ helper (default-allow True)."""
+ from app.config import config
+ from app.routes import new_llm_config_routes
+
+ monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)
+
+ payload = await new_llm_config_routes.get_global_new_llm_configs(user=None)
+ by_id = {c["id"]: c for c in payload}
+
+ # Auto stub: optimistic True so the user can keep Auto selected with
+ # vision-capable deployments somewhere in the pool.
+ assert 0 in by_id, "Auto stub should be emitted when configs exist"
+ assert by_id[0]["supports_image_input"] is True
+ assert by_id[0]["is_auto_mode"] is True
+
+ # Explicit True is preserved.
+ assert by_id[-1]["supports_image_input"] is True
+
+ # Explicit False is preserved (the exact failure mode the safety net
+ # guards against — DeepSeek V3 over OpenRouter would 404 with "No
+ # endpoints found that support image input").
+ assert by_id[-2]["supports_image_input"] is False
+
+ # Unannotated GPT-4o: resolver consults LiteLLM, which says vision.
+ assert by_id[-10_010]["supports_image_input"] is True
+
+ # Unknown / unmapped model: default-allow rather than pre-judge.
+ assert by_id[-10_011]["supports_image_input"] is True
+
+ for cfg in payload:
+ assert "supports_image_input" in cfg, (
+ f"supports_image_input missing from {cfg.get('id')}"
+ )
+ assert isinstance(cfg["supports_image_input"], bool)
diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py
new file mode 100644
index 000000000..636b7de31
--- /dev/null
+++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py
@@ -0,0 +1,138 @@
+"""Unit tests for the image-generation route's billing-resolution helper.
+
+End-to-end "POST /image-generations returns 402" coverage requires the
+integration harness (real DB, real auth) and lives in
+``tests/integration/document_upload/`` alongside the other quota tests.
+This unit test focuses on the new ``_resolve_billing_for_image_gen``
+helper which:
+
+* Returns ``free`` for Auto mode, even when premium configs exist
+ (Auto-mode billing-tier surfacing is a follow-up).
+* Returns ``free`` for user-owned BYOK configs (positive IDs).
+* Returns the global config's ``billing_tier`` for negative IDs.
+* Honours the per-config ``quota_reserve_micros`` override when present.
+"""
+
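+# Decision order as a sketch (illustrative, inferred from the tests
+# below; the real helper is ``_resolve_billing_for_image_gen``):
+#
+#     cfg_id = config_id if config_id is not None \
+#         else search_space.image_generation_config_id
+#     if cfg_id == 0:  # Auto mode
+#         return "free", "auto", DEFAULT_IMAGE_RESERVE_MICROS
+#     if cfg_id > 0:   # user-owned BYOK
+#         return "free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS
+#     cfg = <global lookup by negative id>
+#     reserve = cfg.get("quota_reserve_micros", DEFAULT_IMAGE_RESERVE_MICROS)
+#     return cfg["billing_tier"], <provider-prefixed model>, reserve
+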
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_for_auto_mode(monkeypatch):
+ from app.routes import image_generation_routes
+ from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
+
+ search_space = SimpleNamespace(image_generation_config_id=None)
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, # Not consumed on this code path.
+ config_id=0, # IMAGE_GEN_AUTO_MODE_ID
+ search_space=search_space,
+ )
+ assert tier == "free"
+ assert model == "auto"
+ assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_for_premium_global_config(monkeypatch):
+ from app.config import config
+ from app.routes import image_generation_routes
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_IMAGE_GEN_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENAI",
+ "model_name": "gpt-image-1",
+ "billing_tier": "premium",
+ "quota_reserve_micros": 75_000,
+ },
+ {
+ "id": -2,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-2.5-flash-image",
+ "billing_tier": "free",
+ },
+ ],
+ raising=False,
+ )
+
+ search_space = SimpleNamespace(image_generation_config_id=None)
+
+ # Premium with override.
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=-1, search_space=search_space
+ )
+ assert tier == "premium"
+ assert model == "openai/gpt-image-1"
+ assert reserve == 75_000
+
+ # Free, no override → falls back to default.
+ from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
+
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=-2, search_space=search_space
+ )
+ assert tier == "free"
+ # Provider-prefixed model string for OpenRouter.
+ assert "google/gemini-2.5-flash-image" in model
+ assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_for_user_owned_byok_is_free():
+ """User-owned BYOK configs (positive IDs) cost the user nothing on
+ our side — they pay the provider directly. Always free.
+ """
+ from app.routes import image_generation_routes
+ from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
+
+ search_space = SimpleNamespace(image_generation_config_id=None)
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=42, search_space=search_space
+ )
+ assert tier == "free"
+ assert model == "user_byok"
+ assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
+ """When the request omits ``image_generation_config_id``, the helper
+ must consult the search space's default — so a search space pinned
+ to a premium global config still gates new requests by quota.
+ """
+ from app.config import config
+ from app.routes import image_generation_routes
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_IMAGE_GEN_CONFIGS",
+ [
+ {
+ "id": -7,
+ "provider": "OPENAI",
+ "model_name": "gpt-image-1",
+ "billing_tier": "premium",
+ }
+ ],
+ raising=False,
+ )
+
+ search_space = SimpleNamespace(image_generation_config_id=-7)
+ (
+ tier,
+ model,
+ _reserve,
+ ) = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=None, search_space=search_space
+ )
+ assert tier == "premium"
+ assert model == "openai/gpt-image-1"
diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py
new file mode 100644
index 000000000..fa8819b39
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py
@@ -0,0 +1,436 @@
+"""Unit tests for ``_resolve_agent_billing_for_search_space``.
+
+Validates the resolver used by Celery podcast/video tasks to compute
+``(owner_user_id, billing_tier, base_model)`` from a search space and its
+agent LLM config. The resolver mirrors chat's billing-resolution pattern at
+``stream_new_chat.py:2294-2351`` and is the single integration point that
+prevents Auto-mode podcast/video from leaking premium credit.
+
+Coverage:
+
+* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
+  global → returns tier ``"premium"`` plus that global's base model.
+* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
+  global → returns tier ``"free"`` plus that global's base model.
+* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
+ → always ``"free"``.
+* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
+ hitting the pin service.
+* Negative id (no Auto) → uses ``get_global_llm_config``'s
+ ``billing_tier``.
+* Positive id (user BYOK) → always ``"free"``.
+* Search space not found → raises ``ValueError``.
+* ``agent_llm_id`` is None → raises ``ValueError``.
+"""
+
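+# Decision order as a sketch (illustrative, inferred from the tests
+# below; the real resolver lives in app/services/billable_calls.py):
+#
+#     if search_space is None:              raise ValueError("Search space ...")
+#     if search_space.agent_llm_id is None: raise ValueError("agent_llm_id ...")
+#     cfg_id = search_space.agent_llm_id
+#     if cfg_id == 0:                       # Auto mode
+#         if thread_id is None:
+#             return owner, "free", "auto"
+#         try:
+#             pin = await resolve_or_get_pinned_llm_config_id(...)
+#             cfg_id = pin.resolved_llm_config_id
+#         except ValueError:
+#             return owner, "free", "auto"  # degrade, don't kill the task
+#     if cfg_id < 0:                        # global config
+#         cfg = get_global_llm_config(cfg_id)
+#         return owner, cfg["billing_tier"], <base_model or model_name>
+#     byok = <NewLLMConfig row lookup>      # user BYOK → never billed
+#     return owner, "free", <byok base_model, or "" if row missing>
+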
+from __future__ import annotations
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+from uuid import UUID, uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeExecResult:
+ def __init__(self, obj):
+ self._obj = obj
+
+ def scalars(self):
+ return self
+
+ def first(self):
+ return self._obj
+
+
+class _FakeSession:
+ """Tiny AsyncSession stub.
+
+ ``responses`` is a list of objects to return from successive
+ ``execute()`` calls (in order). The resolver makes at most two
+ ``execute()`` calls (search-space lookup, then optionally NewLLMConfig
+ lookup), so two queued responses cover the matrix.
+ """
+
+ def __init__(self, responses: list):
+ self._responses = list(responses)
+
+ async def execute(self, _stmt):
+ if not self._responses:
+ return _FakeExecResult(None)
+ return _FakeExecResult(self._responses.pop(0))
+
+ async def commit(self) -> None:
+ pass
+
+
+@dataclass
+class _FakePinResolution:
+ resolved_llm_config_id: int
+ resolved_tier: str = "premium"
+ from_existing_pin: bool = False
+
+
+def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
+ return SimpleNamespace(
+ id=42,
+ agent_llm_id=agent_llm_id,
+ user_id=user_id,
+ )
+
+
+def _make_byok_config(
+ *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
+) -> SimpleNamespace:
+ return SimpleNamespace(
+ id=id_,
+ model_name=model_name,
+ litellm_params={"base_model": base_model} if base_model else {},
+ )
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
+ """Auto + thread → pin service resolves to negative-id premium config →
+ resolver returns ``("premium", )``."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ # Mock the pin service to return a concrete premium config id.
+ async def _fake_resolve_pin(
+ sess,
+ *,
+ thread_id,
+ search_space_id,
+ user_id,
+ selected_llm_config_id,
+ force_repin_free=False,
+ ):
+ assert selected_llm_config_id == 0
+ assert thread_id == 99
+ return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")
+
+ # Mock global config lookup to return a premium entry.
+ def _fake_get_global(cfg_id):
+ if cfg_id == -1:
+ return {
+ "id": -1,
+ "model_name": "gpt-5.4",
+ "billing_tier": "premium",
+ "litellm_params": {"base_model": "gpt-5.4"},
+ }
+ return None
+
+ # Lazy imports inside the resolver — patch the *target* modules so the
+ # imported names resolve to our fakes.
+ import app.services.auto_model_pin_service as pin_module
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "premium"
+ assert base_model == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
+ """Auto + thread → pin returns negative-id free config → resolver
+ returns ``("free", )``. Same path the pin service takes for
+ out-of-credit users (graceful degradation)."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ async def _fake_resolve_pin(
+ sess,
+ *,
+ thread_id,
+ search_space_id,
+ user_id,
+ selected_llm_config_id,
+ force_repin_free=False,
+ ):
+ return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")
+
+ def _fake_get_global(cfg_id):
+ if cfg_id == -3:
+ return {
+ "id": -3,
+ "model_name": "openrouter/free-model",
+ "billing_tier": "free",
+ "litellm_params": {"base_model": "openrouter/free-model"},
+ }
+ return None
+
+ import app.services.auto_model_pin_service as pin_module
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "openrouter/free-model"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
+ """Auto + thread → pin returns positive-id BYOK config → resolver
+ returns ``("free", ...)`` (BYOK is always free per
+ ``AgentConfig.from_new_llm_config``)."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
+ byok_cfg = _make_byok_config(
+ id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
+ )
+ session = _FakeSession([search_space, byok_cfg])
+
+ async def _fake_resolve_pin(
+ sess,
+ *,
+ thread_id,
+ search_space_id,
+ user_id,
+ selected_llm_config_id,
+ force_repin_free=False,
+ ):
+ return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")
+
+ import app.services.auto_model_pin_service as pin_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "anthropic/claude-3-haiku"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_without_thread_id_falls_back_to_free():
+ """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
+ the pin service. Forward-compat fallback for any future direct-API
+ entrypoint that doesn't have a chat thread."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=None
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "auto"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
+ """If the pin service raises ``ValueError`` (thread missing /
+ mismatched search space), the resolver should log and return free
+ rather than killing the whole task."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ async def _fake_resolve_pin(*args, **kwargs):
+ raise ValueError("thread missing")
+
+ import app.services.auto_model_pin_service as pin_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "auto"
+
+
+@pytest.mark.asyncio
+async def test_negative_id_premium_global_returns_premium(monkeypatch):
+ """Explicit negative agent_llm_id → ``get_global_llm_config`` →
+ return its ``billing_tier``."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])
+
+ def _fake_get_global(cfg_id):
+ return {
+ "id": cfg_id,
+ "model_name": "gpt-5.4",
+ "billing_tier": "premium",
+ "litellm_params": {"base_model": "gpt-5.4"},
+ }
+
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "premium"
+ assert base_model == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_negative_id_free_global_returns_free(monkeypatch):
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])
+
+ def _fake_get_global(cfg_id):
+ return {
+ "id": cfg_id,
+ "model_name": "openrouter/some-free",
+ "billing_tier": "free",
+ "litellm_params": {"base_model": "openrouter/some-free"},
+ }
+
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=None
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "openrouter/some-free"
+
+
+@pytest.mark.asyncio
+async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
+ """When the global config has no ``litellm_params.base_model``, the
+ resolver falls back to ``model_name`` — matching chat's behaviour."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])
+
+ def _fake_get_global(cfg_id):
+ return {
+ "id": cfg_id,
+ "model_name": "fallback-model",
+ "billing_tier": "premium",
+ # No litellm_params.
+ }
+
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ _, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42
+ )
+
+ assert tier == "premium"
+ assert base_model == "fallback-model"
+
+
+@pytest.mark.asyncio
+async def test_positive_id_byok_is_always_free():
+ """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
+ regardless of underlying provider tier."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
+ byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
+ session = _FakeSession([search_space, byok_cfg])
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "anthropic/claude-3.5-sonnet"
+
+
+@pytest.mark.asyncio
+async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
+ """If the BYOK config row is missing/deleted but the search space still
+ points at it, the resolver still returns free (no debit) with an empty
+ base_model — billable_call's premium path is skipped, no harm done."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == ""
+
+
+@pytest.mark.asyncio
+async def test_search_space_not_found_raises_value_error():
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ session = _FakeSession([None])
+
+ with pytest.raises(ValueError, match="Search space"):
+ await _resolve_agent_billing_for_search_space(session, search_space_id=999)
+
+
+@pytest.mark.asyncio
+async def test_agent_llm_id_none_raises_value_error():
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])
+
+ with pytest.raises(ValueError, match="agent_llm_id"):
+ await _resolve_agent_billing_for_search_space(session, search_space_id=42)
diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
new file mode 100644
index 000000000..d1af29aeb
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
@@ -0,0 +1,1026 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+
+import pytest
+
+from app.services.auto_model_pin_service import (
+ clear_healthy,
+ clear_runtime_cooldown,
+ is_recently_healthy,
+ mark_healthy,
+ mark_runtime_cooldown,
+ resolve_or_get_pinned_llm_config_id,
+)
+
+pytestmark = pytest.mark.unit
+
+
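+# The pin service keeps module-level cooldown / recent-health maps;
+# reset them around every test so state can't leak across cases.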
+@pytest.fixture(autouse=True)
+def _clear_runtime_cooldown_map():
+ clear_runtime_cooldown()
+ clear_healthy()
+ yield
+ clear_runtime_cooldown()
+ clear_healthy()
+
+
+@dataclass
+class _FakeQuotaResult:
+ allowed: bool
+
+
+class _FakeExecResult:
+ def __init__(self, thread):
+ self._thread = thread
+
+ def unique(self):
+ return self
+
+ def scalar_one_or_none(self):
+ return self._thread
+
+
+class _FakeSession:
+ def __init__(self, thread):
+ self.thread = thread
+ self.commit_count = 0
+
+ async def execute(self, _stmt):
+ return _FakeExecResult(self.thread)
+
+ async def commit(self):
+ self.commit_count += 1
+
+
+def _thread(
+ *,
+ search_space_id: int = 10,
+ pinned_llm_config_id: int | None = None,
+):
+ return SimpleNamespace(
+ id=1,
+ search_space_id=search_space_id,
+ pinned_llm_config_id=pinned_llm_config_id,
+ )
+
+
+@pytest.mark.asyncio
+async def test_auto_first_turn_pins_one_model(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
+ {
+ "id": -1,
+ "provider": "OPENAI",
+ "model_name": "gpt-prem",
+ "api_key": "k2",
+ "billing_tier": "premium",
+ },
+ ],
+ )
+
+ async def _allowed(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=True)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -1
+ assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id
+ assert session.commit_count == 1
+
+
+@pytest.mark.asyncio
+async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -2,
+ "provider": "OPENAI",
+ "model_name": "gpt-free",
+ "api_key": "k1",
+ "billing_tier": "free",
+ "quality_score": 100,
+ },
+ {
+ "id": -1,
+ "provider": "OPENAI",
+ "model_name": "gpt-prem",
+ "api_key": "k2",
+ "billing_tier": "premium",
+ "quality_score": 10,
+ },
+ ],
+ )
+
+ async def _allowed(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=True)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -1
+ assert result.resolved_tier == "premium"
+
+
+@pytest.mark.asyncio
+async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5.1",
+ "api_key": "k1",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score": 100,
+ },
+ {
+ "id": -2,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5.4",
+ "api_key": "k2",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score": 10,
+ },
+ {
+ "id": -3,
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-5.4",
+ "api_key": "k3",
+ "billing_tier": "premium",
+ "auto_pin_tier": "B",
+ "quality_score": 100,
+ },
+ ],
+ )
+
+ async def _allowed(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=True)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -2
+ assert result.resolved_tier == "premium"
+
+
+@pytest.mark.asyncio
+async def test_next_turn_reuses_existing_pin(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned_llm_config_id=-1))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENAI",
+ "model_name": "gpt-prem",
+ "api_key": "k2",
+ "billing_tier": "premium",
+ },
+ ],
+ )
+
+ async def _must_not_call(*_args, **_kwargs):
+ raise AssertionError(
+ "premium_get_usage should not be called for valid pin reuse"
+ )
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _must_not_call,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -1
+ assert result.from_existing_pin is True
+ assert session.commit_count == 0
+
+
+@pytest.mark.asyncio
+async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENAI",
+ "model_name": "gpt-prem",
+ "api_key": "k2",
+ "billing_tier": "premium",
+ },
+ ],
+ )
+
+ async def _allowed(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=True)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -1
+ assert result.resolved_tier == "premium"
+
+
+@pytest.mark.asyncio
+async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -2,
+ "provider": "OPENAI",
+ "model_name": "gpt-free",
+ "api_key": "k1",
+ "billing_tier": "free",
+ },
+ {
+ "id": -1,
+ "provider": "OPENAI",
+ "model_name": "gpt-prem",
+ "api_key": "k2",
+ "billing_tier": "premium",
+ },
+ ],
+ )
+
+ async def _blocked(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=False)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _blocked,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -2
+ assert result.resolved_tier == "free"
+
+
+@pytest.mark.asyncio
+async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned_llm_config_id=-1))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -2,
+ "provider": "OPENAI",
+ "model_name": "gpt-free",
+ "api_key": "k1",
+ "billing_tier": "free",
+ },
+ {
+ "id": -1,
+ "provider": "OPENAI",
+ "model_name": "gpt-prem",
+ "api_key": "k2",
+ "billing_tier": "premium",
+ },
+ ],
+ )
+
+ async def _blocked(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=False)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _blocked,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -1
+ assert result.from_existing_pin is True
+
+
+@pytest.mark.asyncio
+async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned_llm_config_id=-1))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -2,
+ "provider": "OPENAI",
+ "model_name": "gpt-free",
+ "api_key": "k1",
+ "billing_tier": "free",
+ },
+ {
+ "id": -1,
+ "provider": "OPENAI",
+ "model_name": "gpt-prem",
+ "api_key": "k2",
+ "billing_tier": "premium",
+ },
+ ],
+ )
+
+ async def _blocked(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=False)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _blocked,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ force_repin_free=True,
+ )
+ assert result.resolved_llm_config_id == -2
+ assert result.resolved_tier == "free"
+ assert result.from_existing_pin is False
+ assert session.thread.pinned_llm_config_id == -2
+
+
+@pytest.mark.asyncio
+async def test_explicit_user_model_change_clears_pin(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned_llm_config_id=-2))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
+ ],
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=7,
+ )
+ assert result.resolved_llm_config_id == 7
+ assert session.thread.pinned_llm_config_id is None
+ assert session.commit_count == 1
+
+
+@pytest.mark.asyncio
+async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned_llm_config_id=-999))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
+ ],
+ )
+
+ async def _blocked(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=False)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _blocked,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -2
+ assert session.thread.pinned_llm_config_id == -2
+ assert session.commit_count == 1
+
+
+# ---------------------------------------------------------------------------
+# Quality-aware pin selection (Auto Fastest upgrade)
+# ---------------------------------------------------------------------------
+
+
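+# A minimal sketch of the selection shape these tests pin (assumed, not the
+# real implementation; the actual K is unknown, so ``_SKETCH_TOP_K`` is
+# illustrative only, and tier locking (A before B/C) happens before this
+# ranking and is not shown): gated cfgs are dropped, the pool is ranked by
+# ``quality_score``, and the thread id spreads picks across the top-K slice.
+_SKETCH_TOP_K = 3
+
+
+def _sketch_pick(candidates: list[dict], thread_id: int) -> dict:
+    # Assumes a non-empty pool of ungated candidates.
+    healthy = [c for c in candidates if not c.get("health_gated")]
+    ranked = sorted(healthy, key=lambda c: c["quality_score"], reverse=True)
+    top = ranked[:_SKETCH_TOP_K]
+    return top[thread_id % len(top)]
+
+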
+@pytest.mark.asyncio
+async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
+ """A cfg flagged ``health_gated`` must never be picked even if it has
+ the highest score among eligible cfgs."""
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENROUTER",
+ "model_name": "venice/dead-model",
+ "api_key": "k1",
+ "billing_tier": "free",
+ "auto_pin_tier": "C",
+ "quality_score": 95,
+ "health_gated": True,
+ },
+ {
+ "id": -2,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-flash",
+ "api_key": "k1",
+ "billing_tier": "free",
+ "auto_pin_tier": "C",
+ "quality_score": 60,
+ "health_gated": False,
+ },
+ ],
+ )
+
+ async def _blocked(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=False)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _blocked,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -2
+
+
+@pytest.mark.asyncio
+async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
+ """Premium-eligible users with Tier A available should never spill to
+ Tier B even if a B cfg ranks higher by ``quality_score``."""
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5",
+ "api_key": "k-yaml",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score": 70,
+ "health_gated": False,
+ },
+ {
+ "id": -2,
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-5",
+ "api_key": "k-or",
+ "billing_tier": "premium",
+ "auto_pin_tier": "B",
+ "quality_score": 95,
+ "health_gated": False,
+ },
+ ],
+ )
+
+ async def _allowed(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=True)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -1
+ assert result.resolved_tier == "premium"
+
+
+@pytest.mark.asyncio
+async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch):
+ """Free-only user with no Tier A free cfg should pick from Tier C."""
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5",
+ "api_key": "k-yaml",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score": 100,
+ "health_gated": False,
+ },
+ {
+ "id": -2,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-flash:free",
+ "api_key": "k-or",
+ "billing_tier": "free",
+ "auto_pin_tier": "C",
+ "quality_score": 60,
+ "health_gated": False,
+ },
+ ],
+ )
+
+ async def _blocked(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=False)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _blocked,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -2
+
+
+@pytest.mark.asyncio
+async def test_top_k_picks_only_high_score_models(monkeypatch):
+ """Different thread IDs should spread across top-K, never pick the
+ obvious low-quality cfg even when it sits in the candidate list."""
+ from app.config import config
+
+ high_score_cfgs = [
+ {
+ "id": -i,
+ "provider": "AZURE_OPENAI",
+ "model_name": f"gpt-x-{i}",
+ "api_key": "k",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score": 90,
+ "health_gated": False,
+ }
+ for i in range(1, 6) # 5 high-quality Tier A cfgs
+ ]
+ low_score_trap = {
+ "id": -99,
+ "provider": "AZURE_OPENAI",
+ "model_name": "tiny-legacy",
+ "api_key": "k",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score": 10,
+ "health_gated": False,
+ }
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [*high_score_cfgs, low_score_trap],
+ )
+
+ async def _allowed(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=True)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _allowed,
+ )
+
+ high_score_ids = {c["id"] for c in high_score_cfgs}
+ seen = set()
+ for thread_id in range(1, 50):
+ session = _FakeSession(_thread())
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=thread_id,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ seen.add(result.resolved_llm_config_id)
+ assert result.resolved_llm_config_id != -99, (
+ "low-score trap cfg should never be picked"
+ )
+ assert result.resolved_llm_config_id in high_score_ids
+
+ # Spread across at least a couple of top-K cfgs.
+ assert len(seen) > 1
+
+
+@pytest.mark.asyncio
+async def test_health_gated_existing_pin_is_not_reused(monkeypatch):
+ """An *already* pinned cfg that later flips to ``health_gated`` should
+ still not be reused — gated cfgs are filtered out of the candidate
+ pool, which forces a repair to a healthy cfg.
+
+ This guards the no-silent-tier-switch invariant: we don't keep using
+ a known-broken model just because the thread happened to be pinned
+ to it before the gate fired."""
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned_llm_config_id=-1))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENROUTER",
+ "model_name": "venice/dead-model",
+ "api_key": "k",
+ "billing_tier": "premium",
+ "auto_pin_tier": "B",
+ "quality_score": 50,
+ "health_gated": True,
+ },
+ {
+ "id": -2,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5",
+ "api_key": "k",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score": 90,
+ "health_gated": False,
+ },
+ ],
+ )
+
+ async def _allowed(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=True)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -2
+ assert result.from_existing_pin is False
+
+
+@pytest.mark.asyncio
+async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
+ """Existing pin reuse must short-circuit the new tier/score logic."""
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned_llm_config_id=-1))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5",
+ "api_key": "k",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score": 50, # lower than -2
+ "health_gated": False,
+ },
+ {
+ "id": -2,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5-pro",
+ "api_key": "k",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score": 99,
+ "health_gated": False,
+ },
+ ],
+ )
+
+ async def _must_not_call(*_args, **_kwargs):
+ raise AssertionError("premium_get_usage should not run on pin reuse")
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _must_not_call,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -1
+ assert result.from_existing_pin is True
+ assert session.commit_count == 0
+
+
+@pytest.mark.asyncio
+async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
+ """A runtime-cooled config should be excluded from candidate reuse.
+
+ This enables one-shot recovery from transient provider 429 bursts: we can
+ mark the pinned cfg as cooled down and force a repair to another eligible
+ cfg on the next resolution.
+ """
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned_llm_config_id=-1))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemma-4-26b-a4b-it:free",
+ "api_key": "k",
+ "billing_tier": "free",
+ "auto_pin_tier": "C",
+ "quality_score": 90,
+ "health_gated": False,
+ },
+ {
+ "id": -2,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-2.5-flash:free",
+ "api_key": "k",
+ "billing_tier": "free",
+ "auto_pin_tier": "C",
+ "quality_score": 80,
+ "health_gated": False,
+ },
+ ],
+ )
+
+ async def _blocked(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=False)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _blocked,
+ )
+
+ mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -2
+ assert result.from_existing_pin is False
+
+
+@pytest.mark.asyncio
+async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned_llm_config_id=-1))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemma-4-26b-a4b-it:free",
+ "api_key": "k",
+ "billing_tier": "free",
+ "auto_pin_tier": "C",
+ "quality_score": 90,
+ "health_gated": False,
+ },
+ ],
+ )
+
+ async def _must_not_call(*_args, **_kwargs):
+ raise AssertionError("premium_get_usage should not run on healthy pin reuse")
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _must_not_call,
+ )
+
+ mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
+ clear_runtime_cooldown(-1)
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -1
+ assert result.from_existing_pin is True
+
+
+@pytest.mark.asyncio
+async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch):
+ """Runtime retry should never repin the just-failed config."""
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned_llm_config_id=-1))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemma-4-26b-a4b-it:free",
+ "api_key": "k",
+ "billing_tier": "free",
+ "auto_pin_tier": "C",
+ "quality_score": 90,
+ "health_gated": False,
+ },
+ {
+ "id": -2,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-2.5-flash:free",
+ "api_key": "k",
+ "billing_tier": "free",
+ "auto_pin_tier": "C",
+ "quality_score": 80,
+ "health_gated": False,
+ },
+ ],
+ )
+
+ async def _blocked(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=False)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _blocked,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ exclude_config_ids={-1},
+ )
+ assert result.resolved_llm_config_id == -2
+ assert result.from_existing_pin is False
+
+
+# ---------------------------------------------------------------------------
+# Healthy-status cache (preflight TTL companion)
+# ---------------------------------------------------------------------------
+
+
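+# A minimal sketch of the cache semantics these tests pin (assumed; sketch
+# names are hypothetical so they don't shadow the real imports). An absolute
+# expiry timestamp is stored per config id and compared against
+# ``time.time()``, which is why the TTL test below monkeypatches
+# ``svc.time.time``.
+_SKETCH_HEALTHY: dict[int, float] = {}
+
+
+def _sketch_mark_healthy(config_id: int, *, ttl_seconds: int) -> None:
+    import time
+
+    _SKETCH_HEALTHY[config_id] = time.time() + ttl_seconds
+
+
+def _sketch_is_recently_healthy(config_id: int) -> bool:
+    import time
+
+    return time.time() < _SKETCH_HEALTHY.get(config_id, 0.0)
+
+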
+def test_mark_healthy_then_is_recently_healthy_true_within_ttl():
+ mark_healthy(-42, ttl_seconds=60)
+ assert is_recently_healthy(-42) is True
+
+
+def test_healthy_expires_after_ttl(monkeypatch):
+ import app.services.auto_model_pin_service as svc
+
+ real_time = svc.time.time
+ base = real_time()
+
+ monkeypatch.setattr(svc.time, "time", lambda: base)
+ mark_healthy(-7, ttl_seconds=10)
+ assert is_recently_healthy(-7) is True
+
+ monkeypatch.setattr(svc.time, "time", lambda: base + 11)
+ assert is_recently_healthy(-7) is False
+
+
+def test_mark_runtime_cooldown_invalidates_healthy_cache():
+ mark_healthy(-9, ttl_seconds=60)
+ assert is_recently_healthy(-9) is True
+
+ mark_runtime_cooldown(-9, reason="test", cooldown_seconds=60)
+ assert is_recently_healthy(-9) is False
+
+
+def test_clear_healthy_removes_single_entry():
+ mark_healthy(-11, ttl_seconds=60)
+ mark_healthy(-12, ttl_seconds=60)
+ clear_healthy(-11)
+ assert is_recently_healthy(-11) is False
+ assert is_recently_healthy(-12) is True
+
+
+def test_clear_healthy_no_args_drops_all_entries():
+ mark_healthy(-21, ttl_seconds=60)
+ mark_healthy(-22, ttl_seconds=60)
+ clear_healthy()
+ assert is_recently_healthy(-21) is False
+ assert is_recently_healthy(-22) is False
diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py
new file mode 100644
index 000000000..0e19b80e4
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py
@@ -0,0 +1,286 @@
+"""Image-aware extension of the Auto-pin resolver.
+
+When the current chat turn carries an ``image_url`` block, the pin
+resolver must:
+
+1. Filter the candidate pool to vision-capable cfgs so a freshly
+ selected pin can never be text-only.
+2. Treat any existing pin whose vision capability resolves to False as
+   invalid (force re-pin), even when it would otherwise be reused as the
+   thread's stable model.
+3. Raise ``ValueError`` (mapped to the friendly
+ ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming
+ task) when no vision-capable cfg is available — instead of silently
+ pinning text-only and 404-ing at the provider.
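+
+A minimal sketch of the capability check described above (assumed shape)
+follows the imports below.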
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+
+import pytest
+
+from app.services.auto_model_pin_service import (
+ clear_healthy,
+ clear_runtime_cooldown,
+ resolve_or_get_pinned_llm_config_id,
+)
+
+pytestmark = pytest.mark.unit
+
+
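+# A minimal sketch of the capability check these tests pin (assumed shape;
+# ``_sketch_vision_capable`` is hypothetical and takes the derive helper as
+# a parameter to stay self-contained): an explicit ``supports_image_input``
+# flag wins, otherwise capability is derived from the model name, which the
+# last test in this module exercises via ``derive_supports_image_input``.
+def _sketch_vision_capable(cfg: dict, derive) -> bool:
+    flag = cfg.get("supports_image_input")
+    if flag is not None:
+        return bool(flag)
+    return derive(cfg["model_name"])
+
+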
+@pytest.fixture(autouse=True)
+def _reset_caches():
+ clear_runtime_cooldown()
+ clear_healthy()
+ yield
+ clear_runtime_cooldown()
+ clear_healthy()
+
+
+@dataclass
+class _FakeQuotaResult:
+ allowed: bool
+
+
+class _FakeExecResult:
+ def __init__(self, thread):
+ self._thread = thread
+
+ def unique(self):
+ return self
+
+ def scalar_one_or_none(self):
+ return self._thread
+
+
+class _FakeSession:
+ def __init__(self, thread):
+ self.thread = thread
+ self.commit_count = 0
+
+ async def execute(self, _stmt):
+ return _FakeExecResult(self.thread)
+
+ async def commit(self):
+ self.commit_count += 1
+
+
+def _thread(*, pinned: int | None = None):
+ return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
+
+
+def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
+ return {
+ "id": id_,
+ "provider": "OPENAI",
+ "model_name": f"vision-{id_}",
+ "api_key": "k",
+ "billing_tier": tier,
+ "supports_image_input": True,
+ "auto_pin_tier": "A",
+ "quality_score": quality,
+ }
+
+
+def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
+ return {
+ "id": id_,
+ "provider": "OPENAI",
+ "model_name": f"text-{id_}",
+ "api_key": "k",
+ "billing_tier": tier,
+ # Higher quality than the vision cfgs — so a bug that ignores
+ # the image flag would surface as the resolver picking this one.
+ "supports_image_input": False,
+ "auto_pin_tier": "A",
+ "quality_score": quality,
+ }
+
+
+async def _premium_allowed(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=True)
+
+
+@pytest.mark.asyncio
+async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [_text_only_cfg(-1), _vision_cfg(-2)],
+ )
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ requires_image_input=True,
+ )
+
+ assert result.resolved_llm_config_id == -2
+ # The thread should be pinned to the vision cfg even though the
+ # text-only cfg has a higher quality score.
+ assert session.thread.pinned_llm_config_id == -2
+
+
+@pytest.mark.asyncio
+async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
+ """An existing text-only pin must be invalidated when the next turn
+ requires image input. The non-image path would happily reuse it."""
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned=-1))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [_text_only_cfg(-1), _vision_cfg(-2)],
+ )
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ requires_image_input=True,
+ )
+
+ assert result.resolved_llm_config_id == -2
+ assert result.from_existing_pin is False
+ assert session.thread.pinned_llm_config_id == -2
+
+
+@pytest.mark.asyncio
+async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
+ """If the thread is already pinned to a vision-capable cfg, reuse it
+ — same as the non-image path. Image-aware filtering must not force
+ spurious re-pins."""
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned=-2))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
+ )
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ requires_image_input=True,
+ )
+
+ assert result.resolved_llm_config_id == -2
+ assert result.from_existing_pin is True
+
+
+@pytest.mark.asyncio
+async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
+ """The friendly-error path: no vision-capable cfg in the pool -> raise
+ ``ValueError`` whose message contains ``vision-capable`` so the
+ streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``."""
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [_text_only_cfg(-1), _text_only_cfg(-2)],
+ )
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ with pytest.raises(ValueError, match="vision-capable"):
+ await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ requires_image_input=True,
+ )
+
+
+@pytest.mark.asyncio
+async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
+ """Regression guard: the image flag must default False and not affect
+ a normal text-only turn — text-only cfgs remain selectable."""
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [_text_only_cfg(-1)],
+ )
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -1
+
+
+@pytest.mark.asyncio
+async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
+ """A YAML cfg that omits ``supports_image_input`` falls through to
+ ``derive_supports_image_input`` (LiteLLM-driven). For ``gpt-4o``
+ that returns True, so the cfg should be a valid candidate."""
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ cfg_unannotated_vision = {
+ "id": -2,
+ "provider": "OPENAI",
+ "model_name": "gpt-4o", # known vision model in LiteLLM map
+ "api_key": "k",
+ "billing_tier": "free",
+ "auto_pin_tier": "A",
+ "quality_score": 80,
+ # NOTE: no supports_image_input key
+ }
+ monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision])
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ requires_image_input=True,
+ )
+ assert result.resolved_llm_config_id == -2
diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py
new file mode 100644
index 000000000..c820724ed
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_billable_call.py
@@ -0,0 +1,559 @@
+"""Unit tests for the ``billable_call`` async context manager.
+
+Covers the per-call premium-credit lifecycle for image generation and
+vision LLM extraction:
+
+* Free configs bypass reserve/finalize but still write an audit row.
+* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
+ route layer).
+* Successful premium calls reserve, yield the accumulator, then finalize
+ with the LiteLLM-reported actual cost — and write an audit row.
+* Failed premium calls release the reservation so credit isn't leaked.
+* All quota DB ops happen inside their OWN ``shielded_async_session``,
+ isolating them from the caller's transaction (issue A).
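+
+A control-flow sketch of this lifecycle (assumed shape) follows the imports
+below.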
+"""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
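+# A minimal control-flow sketch of the lifecycle under test (assumed shape;
+# the real ``billable_call`` takes different parameters and raises typed
+# billing errors): premium reserves up front, finalizes on success, and
+# releases on failure, while free skips all three and only audits.
+@contextlib.asynccontextmanager
+async def _sketch_billable_call(*, premium: bool, reserve, finalize, release):
+    reserved = await reserve() if premium else 0
+    try:
+        yield []  # stands in for the cost accumulator
+    except BaseException:
+        if premium:
+            await release(reserved)
+        raise
+    if premium:
+        await finalize(reserved)
+
+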
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeQuotaResult:
+ def __init__(
+ self,
+ *,
+ allowed: bool,
+ used: int = 0,
+ limit: int = 5_000_000,
+ remaining: int = 5_000_000,
+ ) -> None:
+ self.allowed = allowed
+ self.used = used
+ self.limit = limit
+ self.remaining = remaining
+
+
+class _FakeSession:
+ """Minimal AsyncSession stub — record commits for assertion."""
+
+ def __init__(self) -> None:
+ self.committed = False
+ self.added: list[Any] = []
+
+ def add(self, obj: Any) -> None:
+ self.added.append(obj)
+
+ async def commit(self) -> None:
+ self.committed = True
+
+ async def rollback(self) -> None:
+ pass
+
+ async def close(self) -> None:
+ pass
+
+
+@contextlib.asynccontextmanager
+async def _fake_shielded_session():
+ s = _FakeSession()
+ _SESSIONS_USED.append(s)
+ yield s
+
+
+_SESSIONS_USED: list[_FakeSession] = []
+
+
+def _patch_isolation_layer(
+ monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None
+):
+ """Wire fake reserve/finalize/release/session helpers."""
+ _SESSIONS_USED.clear()
+ reserve_calls: list[dict[str, Any]] = []
+ finalize_calls: list[dict[str, Any]] = []
+ release_calls: list[dict[str, Any]] = []
+
+ async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
+ reserve_calls.append(
+ {
+ "user_id": user_id,
+ "reserve_micros": reserve_micros,
+ "request_id": request_id,
+ }
+ )
+ return reserve_result
+
+ async def _fake_finalize(
+ *, db_session, user_id, request_id, actual_micros, reserved_micros
+ ):
+ if finalize_exc is not None:
+ raise finalize_exc
+ finalize_calls.append(
+ {
+ "user_id": user_id,
+ "actual_micros": actual_micros,
+ "reserved_micros": reserved_micros,
+ }
+ )
+ return finalize_result or _FakeQuotaResult(allowed=True)
+
+ async def _fake_release(*, db_session, user_id, reserved_micros):
+ release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros})
+
+ record_calls: list[dict[str, Any]] = []
+
+ async def _fake_record(session, **kwargs):
+ record_calls.append(kwargs)
+ return object()
+
+ monkeypatch.setattr(
+ "app.services.billable_calls.TokenQuotaService.premium_reserve",
+ _fake_reserve,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.TokenQuotaService.premium_finalize",
+ _fake_finalize,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.TokenQuotaService.premium_release",
+ _fake_release,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.shielded_async_session",
+ _fake_shielded_session,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.record_token_usage",
+ _fake_record,
+ raising=False,
+ )
+
+ return {
+ "reserve": reserve_calls,
+ "finalize": finalize_calls,
+ "release": release_calls,
+ "record": record_calls,
+ }
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="free",
+ base_model="openai/gpt-image-1",
+ usage_type="image_generation",
+ ) as acc:
+        # Simulate a captured cost: in real life the accumulator is fed by
+        # the LiteLLM callback; here we add to it manually.
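+        # cost_micros is in micro-USD, so 37_000 == $0.037.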
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=37_000,
+ call_kind="image_generation",
+ )
+
+ assert spies["reserve"] == []
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+ # Free still audits.
+ assert len(spies["record"]) == 1
+ assert spies["record"][0]["usage_type"] == "image_generation"
+ assert spies["record"][0]["cost_micros"] == 37_000
+
+
+@pytest.mark.asyncio
+async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
+ from app.services.billable_calls import (
+ QuotaInsufficientError,
+ billable_call,
+ )
+
+ spies = _patch_isolation_layer(
+ monkeypatch,
+ reserve_result=_FakeQuotaResult(
+ allowed=False, used=5_000_000, limit=5_000_000, remaining=0
+ ),
+ )
+ user_id = uuid4()
+
+ with pytest.raises(QuotaInsufficientError) as exc_info:
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ):
+ pytest.fail("body should not run when reserve is denied")
+
+ err = exc_info.value
+ assert err.usage_type == "image_generation"
+ assert err.used_micros == 5_000_000
+ assert err.limit_micros == 5_000_000
+ assert err.remaining_micros == 0
+ # Reserve was attempted, but no finalize/release on a denied reserve
+ # — the reservation never actually held credit.
+ assert len(spies["reserve"]) == 1
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+ # Denied premium calls do NOT create an audit row (no work happened).
+ assert spies["record"] == []
+
+
+@pytest.mark.asyncio
+async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ) as acc:
+ # LiteLLM callback would normally fill this — simulate $0.04 image.
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=40_000,
+ call_kind="image_generation",
+ )
+
+ assert len(spies["reserve"]) == 1
+ assert spies["reserve"][0]["reserve_micros"] == 50_000
+ assert len(spies["finalize"]) == 1
+ assert spies["finalize"][0]["actual_micros"] == 40_000
+ assert spies["finalize"][0]["reserved_micros"] == 50_000
+ assert spies["release"] == []
+ # And audit row written with the actual debited cost.
+ assert spies["record"][0]["cost_micros"] == 40_000
+ # Each quota op opened its OWN session — proves session isolation.
+ assert len(_SESSIONS_USED) >= 3
+    # reserve/finalize go through the stubbed TokenQuotaService.* helpers,
+    # which never commit on our fake sessions; only the audit session
+    # commits, so assert that at least one session committed (not all).
+    assert any(s.committed for s in _SESSIONS_USED)
+
+
+@pytest.mark.asyncio
+async def test_premium_failure_releases_reservation(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ class _ProviderError(Exception):
+ pass
+
+ with pytest.raises(_ProviderError):
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ):
+ raise _ProviderError("OpenRouter 503")
+
+ assert len(spies["reserve"]) == 1
+ assert spies["finalize"] == []
+ # Failure path: release the held reservation.
+ assert len(spies["release"]) == 1
+ assert spies["release"][0]["reserved_micros"] == 50_000
+
+
+@pytest.mark.asyncio
+async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
+ """When ``quota_reserve_micros_override`` is None we fall back to
+ ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.
+ Vision LLM calls take this path (token-priced models).
+ """
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+
+ captured_estimator_calls: list[dict[str, Any]] = []
+
+ def _fake_estimate(*, base_model, quota_reserve_tokens):
+ captured_estimator_calls.append(
+ {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
+ )
+ return 12_345
+
+ monkeypatch.setattr(
+ "app.services.billable_calls.estimate_call_reserve_micros",
+ _fake_estimate,
+ raising=False,
+ )
+
+ user_id = uuid4()
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=1,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ usage_type="vision_extraction",
+ ):
+ pass
+
+ assert captured_estimator_calls == [
+ {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
+ ]
+ assert spies["reserve"][0]["reserve_micros"] == 12_345
+
+
+@pytest.mark.asyncio
+async def test_premium_finalize_failure_propagates_and_releases(monkeypatch):
+ from app.services.billable_calls import BillingSettlementError, billable_call
+
+ class _FinalizeError(RuntimeError):
+ pass
+
+ spies = _patch_isolation_layer(
+ monkeypatch,
+ reserve_result=_FakeQuotaResult(allowed=True),
+ finalize_exc=_FinalizeError("db finalize failed"),
+ )
+ user_id = uuid4()
+
+ with pytest.raises(BillingSettlementError):
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ) as acc:
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=40_000,
+ call_kind="image_generation",
+ )
+
+ assert len(spies["reserve"]) == 1
+ assert len(spies["release"]) == 1
+ assert spies["record"] == []
+
+
+@pytest.mark.asyncio
+async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ class _HangingCommitSession(_FakeSession):
+ async def commit(self) -> None:
+ await asyncio.sleep(60)
+
+ @contextlib.asynccontextmanager
+ async def _hanging_session_factory():
+ s = _HangingCommitSession()
+ _SESSIONS_USED.append(s)
+ yield s
+
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ billable_session_factory=_hanging_session_factory,
+ audit_timeout_seconds=0.01,
+ ) as acc:
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=40_000,
+ call_kind="image_generation",
+ )
+
+ assert len(spies["reserve"]) == 1
+ assert len(spies["finalize"]) == 1
+ assert len(spies["record"]) == 1
+ assert spies["release"] == []
+
+
+@pytest.mark.asyncio
+async def test_free_audit_failure_is_best_effort(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+
+ async def _failing_record(_session, **_kwargs):
+ raise RuntimeError("audit insert failed")
+
+ monkeypatch.setattr(
+ "app.services.billable_calls.record_token_usage",
+ _failing_record,
+ raising=False,
+ )
+
+ async with billable_call(
+ user_id=uuid4(),
+ search_space_id=42,
+ billing_tier="free",
+ base_model="openai/gpt-image-1",
+ usage_type="image_generation",
+ audit_timeout_seconds=0.01,
+ ) as acc:
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=37_000,
+ call_kind="image_generation",
+ )
+
+ assert spies["reserve"] == []
+ assert spies["finalize"] == []
+
+
+# ---------------------------------------------------------------------------
+# Podcast / video-presentation usage_type coverage
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
+ """Free podcast configs must skip reserve/finalize but still emit a
+ ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
+ have full audit coverage of free-tier agent runs."""
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="free",
+ base_model="openrouter/some-free-model",
+ quota_reserve_micros_override=200_000,
+ usage_type="podcast_generation",
+ thread_id=99,
+ call_details={"podcast_id": 7, "title": "Test Podcast"},
+ ) as acc:
+ # Two transcript LLM calls aggregated into one accumulator.
+ acc.add(
+ model="openrouter/some-free-model",
+ prompt_tokens=1500,
+ completion_tokens=8000,
+ total_tokens=9500,
+ cost_micros=0,
+ call_kind="chat",
+ )
+
+ assert spies["reserve"] == []
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+
+ assert len(spies["record"]) == 1
+ row = spies["record"][0]
+ assert row["usage_type"] == "podcast_generation"
+ assert row["thread_id"] is None
+ assert row["search_space_id"] == 42
+ assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
+
+
+@pytest.mark.asyncio
+async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
+ """Premium video-presentation runs that hit a denied reservation must
+ raise ``QuotaInsufficientError`` *before* the graph runs and must not
+ emit an audit row (no work happened)."""
+ from app.services.billable_calls import (
+ QuotaInsufficientError,
+ billable_call,
+ )
+
+ spies = _patch_isolation_layer(
+ monkeypatch,
+ reserve_result=_FakeQuotaResult(
+ allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
+ ),
+ )
+ user_id = uuid4()
+
+ with pytest.raises(QuotaInsufficientError) as exc_info:
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="gpt-5.4",
+ quota_reserve_micros_override=1_000_000,
+ usage_type="video_presentation_generation",
+ thread_id=99,
+ call_details={"video_presentation_id": 12, "title": "Test Video"},
+ ):
+ pytest.fail("body should not run when reserve is denied")
+
+ err = exc_info.value
+ assert err.usage_type == "video_presentation_generation"
+ assert err.remaining_micros == 500_000
+ assert spies["reserve"][0]["reserve_micros"] == 1_000_000
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+ assert spies["record"] == []
diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py
new file mode 100644
index 000000000..9d5fdb190
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py
@@ -0,0 +1,177 @@
+"""Defense-in-depth: image-gen call sites must not let an empty
+``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.
+
+The bug repro: an OpenRouter image-gen config ships
+``api_base=""``. The pre-fix call site in
+``image_generation_routes._execute_image_generation`` did
+``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
+silently dropped the empty string. LiteLLM then fell back to
+``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
+and OpenRouter's ``image_generation/transformation`` appended
+``/chat/completions`` to it → 404 ``Resource not found``.
+
+This test pins the post-fix behaviour: with an empty ``api_base`` in
+the config, the call site MUST set ``api_base`` to OpenRouter's public
+URL instead of leaving it unset.
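+
+A one-line sketch of the fixed resolution (assumed helper shape) follows
+the imports below.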
+"""
+
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
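+# A minimal sketch of the post-fix resolution these tests pin (assumed; the
+# real call sites may inline this rather than share a helper):
+def _sketch_resolve_openrouter_api_base(cfg: dict) -> str:
+    # ``or`` treats the empty string like a missing key, so the call can
+    # never fall through to the module-global ``litellm.api_base``.
+    return cfg.get("api_base") or "https://openrouter.ai/api/v1"
+
+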
+@pytest.mark.asyncio
+async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
+ """The global-config branch (``config_id < 0``) of
+ ``_execute_image_generation`` must apply the resolver and pin
+ ``api_base`` to OpenRouter when the config ships an empty string.
+ """
+ from app.routes import image_generation_routes
+
+ cfg = {
+ "id": -20_001,
+ "name": "GPT Image 1 (OpenRouter)",
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-image-1",
+ "api_key": "sk-or-test",
+ "api_base": "", # the original bug shape
+ "api_version": None,
+ "litellm_params": {},
+ }
+
+ captured: dict = {}
+
+ async def fake_aimage_generation(**kwargs):
+ captured.update(kwargs)
+ return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})
+
+ image_gen = MagicMock()
+ image_gen.image_generation_config_id = cfg["id"]
+ image_gen.prompt = "test"
+ image_gen.n = 1
+ image_gen.quality = None
+ image_gen.size = None
+ image_gen.style = None
+ image_gen.response_format = None
+ image_gen.model = None
+
+ search_space = MagicMock()
+ search_space.image_generation_config_id = cfg["id"]
+ session = MagicMock()
+
+ with (
+ patch.object(
+ image_generation_routes,
+ "_get_global_image_gen_config",
+ return_value=cfg,
+ ),
+ patch.object(
+ image_generation_routes,
+ "aimage_generation",
+ side_effect=fake_aimage_generation,
+ ),
+ ):
+ await image_generation_routes._execute_image_generation(
+ session=session, image_gen=image_gen, search_space=search_space
+ )
+
+ # The whole point of the fix: even with empty ``api_base`` in the
+ # config, we forward OpenRouter's public URL so the call doesn't
+ # inherit an Azure endpoint.
+ assert captured.get("api_base") == "https://openrouter.ai/api/v1"
+ assert captured["model"] == "openrouter/openai/gpt-image-1"
+
+
+@pytest.mark.asyncio
+async def test_generate_image_tool_global_sets_api_base_when_config_empty():
+ """Same defense at the agent tool entry point — both surfaces share
+ the same OpenRouter config payloads."""
+ from app.agents.new_chat.tools import generate_image as gi_module
+
+ cfg = {
+ "id": -20_001,
+ "name": "GPT Image 1 (OpenRouter)",
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-image-1",
+ "api_key": "sk-or-test",
+ "api_base": "",
+ "api_version": None,
+ "litellm_params": {},
+ }
+
+ captured: dict = {}
+
+ async def fake_aimage_generation(**kwargs):
+ captured.update(kwargs)
+ response = MagicMock()
+ response.model_dump.return_value = {
+ "data": [{"url": "https://example.com/x.png"}]
+ }
+ response._hidden_params = {"model": "openrouter/openai/gpt-image-1"}
+ return response
+
+ search_space = MagicMock()
+ search_space.id = 1
+ search_space.image_generation_config_id = cfg["id"]
+
+ session_cm = AsyncMock()
+ session = AsyncMock()
+ session_cm.__aenter__.return_value = session
+
+ scalars = MagicMock()
+ scalars.first.return_value = search_space
+ exec_result = MagicMock()
+ exec_result.scalars.return_value = scalars
+ session.execute.return_value = exec_result
+ session.add = MagicMock()
+ session.commit = AsyncMock()
+ session.refresh = AsyncMock()
+
+ # ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback.
+ async def _refresh(obj):
+ obj.id = 1
+
+ session.refresh.side_effect = _refresh
+
+ with (
+ patch.object(gi_module, "shielded_async_session", return_value=session_cm),
+ patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
+ patch.object(
+ gi_module, "aimage_generation", side_effect=fake_aimage_generation
+ ),
+ patch.object(
+ gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0
+ ),
+ ):
+ tool = gi_module.create_generate_image_tool(
+ search_space_id=1, db_session=MagicMock()
+ )
+ await tool.ainvoke({"prompt": "a cat", "n": 1})
+
+ assert captured.get("api_base") == "https://openrouter.ai/api/v1"
+ assert captured["model"] == "openrouter/openai/gpt-image-1"
+
+
+def test_image_gen_router_deployment_sets_api_base_when_config_empty():
+ """The Auto-mode router pool must also resolve ``api_base`` when an
+ OpenRouter config ships an empty string. The deployment dict is fed
+ straight to ``litellm.Router``, so a missing ``api_base`` would
+ leak the same way as the direct call sites.
+ """
+ from app.services.image_gen_router_service import ImageGenRouterService
+
+ deployment = ImageGenRouterService._config_to_deployment(
+ {
+ "model_name": "openai/gpt-image-1",
+ "provider": "OPENROUTER",
+ "api_key": "sk-or-test",
+ "api_base": "",
+ }
+ )
+ assert deployment is not None
+ assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
+ assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-image-1"
diff --git a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py
new file mode 100644
index 000000000..c309ff881
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py
@@ -0,0 +1,226 @@
+"""LLMRouterService pool-filter / rebuild tests.
+
+These tests focus on the *config plumbing* (which configs enter the router
+pool, rebuild resets state correctly). They stub out the underlying
+``litellm.Router`` so we don't need real API keys or network access.
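+
+A sketch of the pool-eligibility predicate (assumed shape) follows the
+imports below.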
+"""
+
+from __future__ import annotations
+
+from unittest.mock import patch
+
+import pytest
+
+from app.services.llm_router_service import LLMRouterService
+
+pytestmark = pytest.mark.unit
+
+
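+# A minimal sketch of the pool filter these tests pin (assumed predicate;
+# the real ``initialize`` may express it differently): configs that lack the
+# flag default into the pool, dynamic configs honor ``router_pool_eligible``.
+def _sketch_pool_eligible(cfg: dict) -> bool:
+    return bool(cfg.get("router_pool_eligible", True))
+
+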
+def _fake_yaml_config(
+ *,
+ id: int,
+ model_name: str,
+ billing_tier: str = "free",
+) -> dict:
+ return {
+ "id": id,
+ "name": f"yaml-{id}",
+ "provider": "OPENAI",
+ "model_name": model_name,
+ "api_key": "sk-test",
+ "api_base": "",
+ "billing_tier": billing_tier,
+ "rpm": 100,
+ "tpm": 100_000,
+ "litellm_params": {},
+ }
+
+
+def _fake_openrouter_config(
+ *,
+ id: int,
+ model_name: str,
+ billing_tier: str,
+ router_pool_eligible: bool | None = None,
+) -> dict:
+ """Build a synthetic dynamic-OR config dict for router-pool tests.
+
+ Defaults mirror Strategy 3: premium OR enters the pool, free OR stays
+ out. Callers can override ``router_pool_eligible`` to simulate legacy
+ configs or to regression-test the filter mechanics directly.
+ """
+ if router_pool_eligible is None:
+ router_pool_eligible = billing_tier == "premium"
+ return {
+ "id": id,
+ "name": f"or-{id}",
+ "provider": "OPENROUTER",
+ "model_name": model_name,
+ "api_key": "sk-or-test",
+ "api_base": "",
+ "billing_tier": billing_tier,
+ "rpm": 20 if billing_tier == "free" else 200,
+ "tpm": 100_000 if billing_tier == "free" else 1_000_000,
+ "litellm_params": {},
+ "router_pool_eligible": router_pool_eligible,
+ }
+
+
+def _reset_router_singleton() -> None:
+ instance = LLMRouterService.get_instance()
+ instance._initialized = False
+ instance._router = None
+ instance._model_list = []
+ instance._premium_model_strings = set()
+
+
+def test_router_pool_includes_or_premium_excludes_or_free():
+ """Strategy 3: premium OR joins the pool, free OR stays out.
+
+ Dynamic OpenRouter premium entries opt into load balancing alongside
+ curated YAML configs. Dynamic OR free entries are intentionally kept
+ out because OpenRouter's free tier enforces a single account-global
+ quota bucket that per-deployment router accounting can't represent.
+ """
+ _reset_router_singleton()
+ configs = [
+ _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
+ _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
+ _fake_openrouter_config(
+ id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
+ ),
+ _fake_openrouter_config(
+ id=-10_002,
+ model_name="meta-llama/llama-3.3-70b:free",
+ billing_tier="free",
+ ),
+ ]
+
+ with (
+ patch("app.services.llm_router_service.Router") as mock_router,
+ patch(
+ "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
+ ) as mock_ctx_fb,
+ ):
+ mock_ctx_fb.side_effect = lambda ml: (ml, None)
+ mock_router.return_value = object()
+ LLMRouterService.initialize(configs)
+
+ pool_models = {
+ dep["litellm_params"]["model"]
+ for dep in LLMRouterService.get_instance()._model_list
+ }
+ # YAML premium + YAML free + dynamic OR premium are all in the pool.
+ # Dynamic OR free is NOT (shared-bucket rate limits can't be load-balanced).
+ assert pool_models == {
+ "openai/gpt-4o",
+ "openai/gpt-4o-mini",
+ "openrouter/openai/gpt-4o",
+ }
+
+ prem = LLMRouterService.get_instance()._premium_model_strings
+ # YAML premium is fingerprinted under both its model_string and its
+ # ``base_model`` form (existing behavior we don't want to regress).
+ assert "openai/gpt-4o" in prem
+ # Dynamic OR premium is now fingerprinted as premium so pool-level
+ # calls through the router are billed against premium quota.
+ assert "openrouter/openai/gpt-4o" in prem
+ assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True
+ # Dynamic OR free never enters the pool, so it's never counted as premium.
+ assert (
+ LLMRouterService.is_premium_model("openrouter/meta-llama/llama-3.3-70b:free")
+ is False
+ )
+
+
+def test_router_pool_filter_mechanics_respect_override():
+ """The ``router_pool_eligible`` filter itself works independently of tier.
+
+    Regression guard: if a future refactor ever sets the flag to False on a
+ premium config (e.g. for maintenance), that config MUST be skipped by
+ ``initialize`` even though its tier is premium.
+ """
+ _reset_router_singleton()
+ configs = [
+ _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
+ _fake_openrouter_config(
+ id=-10_001,
+ model_name="openai/gpt-4o",
+ billing_tier="premium",
+ router_pool_eligible=False, # opt out despite being premium
+ ),
+ ]
+
+ with (
+ patch("app.services.llm_router_service.Router") as mock_router,
+ patch(
+ "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
+ ) as mock_ctx_fb,
+ ):
+ mock_ctx_fb.side_effect = lambda ml: (ml, None)
+ mock_router.return_value = object()
+ LLMRouterService.initialize(configs)
+
+ pool_models = {
+ dep["litellm_params"]["model"]
+ for dep in LLMRouterService.get_instance()._model_list
+ }
+ assert pool_models == {"openai/gpt-4o"}
+ assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is False
+
+
+def test_rebuild_refreshes_pool_after_configs_change():
+ _reset_router_singleton()
+ configs_v1 = [
+ _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
+ ]
+ configs_v2 = [
+ *configs_v1,
+ _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
+ ]
+
+ with (
+ patch("app.services.llm_router_service.Router") as mock_router,
+ patch(
+ "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
+ ) as mock_ctx_fb,
+ ):
+ mock_ctx_fb.side_effect = lambda ml: (ml, None)
+ mock_router.return_value = object()
+
+ LLMRouterService.initialize(configs_v1)
+ assert len(LLMRouterService.get_instance()._model_list) == 1
+
+ # ``initialize`` should be a no-op here (already initialized).
+ LLMRouterService.initialize(configs_v2)
+ assert len(LLMRouterService.get_instance()._model_list) == 1
+
+ # ``rebuild`` must clear the guard and re-run with the new configs.
+ LLMRouterService.rebuild(configs_v2)
+ assert len(LLMRouterService.get_instance()._model_list) == 2
+
+
+def test_auto_model_pin_candidates_include_dynamic_openrouter():
+ """Dynamic OR configs must remain Auto-mode thread-pin candidates.
+
+ Guards against a future regression where someone adds the
+ ``router_pool_eligible`` filter to ``auto_model_pin_service._global_candidates``.
+ """
+ from app.config import config
+ from app.services.auto_model_pin_service import _global_candidates
+
+ or_premium = _fake_openrouter_config(
+ id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
+ )
+ or_free = _fake_openrouter_config(
+ id=-10_002,
+ model_name="meta-llama/llama-3.3-70b:free",
+ billing_tier="free",
+ )
+ original = config.GLOBAL_LLM_CONFIGS
+ try:
+ config.GLOBAL_LLM_CONFIGS = [or_premium, or_free]
+ candidate_ids = {c["id"] for c in _global_candidates()}
+ assert candidate_ids == {-10_001, -10_002}
+ finally:
+ config.GLOBAL_LLM_CONFIGS = original
diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py
new file mode 100644
index 000000000..88fcf2db3
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py
@@ -0,0 +1,380 @@
+"""Unit tests for the dynamic OpenRouter integration."""
+
+from __future__ import annotations
+
+import pytest
+
+from app.services.openrouter_integration_service import (
+ _OPENROUTER_DYNAMIC_MARKER,
+ _generate_configs,
+ _openrouter_tier,
+ _stable_config_id,
+)
+
+pytestmark = pytest.mark.unit
+
+
+def _minimal_openrouter_model(
+ *,
+ model_id: str,
+ pricing: dict | None = None,
+ name: str | None = None,
+) -> dict:
+ """Return a synthetic OpenRouter /api/v1/models entry.
+
+ The real API payload includes a lot of fields; we only populate what
+ ``_generate_configs`` actually inspects (architecture, tool support,
+ context, pricing, id).
+ """
+ return {
+ "id": model_id,
+ "name": name or model_id,
+ "architecture": {"output_modalities": ["text"]},
+ "supported_parameters": ["tools"],
+ "context_length": 200_000,
+ "pricing": pricing or {"prompt": "0.000003", "completion": "0.000015"},
+ }
+
+
+# ---------------------------------------------------------------------------
+# _openrouter_tier
+# ---------------------------------------------------------------------------
+
+
+def test_openrouter_tier_free_suffix():
+ assert _openrouter_tier({"id": "foo/bar:free"}) == "free"
+
+
+def test_openrouter_tier_zero_pricing():
+ model = {
+ "id": "foo/bar",
+ "pricing": {"prompt": "0", "completion": "0"},
+ }
+ assert _openrouter_tier(model) == "free"
+
+
+def test_openrouter_tier_paid():
+ model = {
+ "id": "foo/bar",
+ "pricing": {"prompt": "0.000003", "completion": "0.000015"},
+ }
+ assert _openrouter_tier(model) == "premium"
+
+
+def test_openrouter_tier_missing_pricing_is_premium():
+ assert _openrouter_tier({"id": "foo/bar"}) == "premium"
+ assert _openrouter_tier({"id": "foo/bar", "pricing": {}}) == "premium"
+
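+
+# Editor's sketch of the tier rule the four tests above pin down, assuming
+# string pricing as OpenRouter returns it (illustrative only; the real
+# ``_openrouter_tier`` is imported at the top of this file).
+def _sketch_tier(model: dict) -> str:
+    if str(model.get("id", "")).endswith(":free"):
+        return "free"
+    pricing = model.get("pricing") or {}
+    try:
+        if float(pricing["prompt"]) == 0 and float(pricing["completion"]) == 0:
+            return "free"
+    except (KeyError, TypeError, ValueError):
+        pass  # missing or unparseable pricing falls through to premium
+    return "premium"
+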
+
+# ---------------------------------------------------------------------------
+# _stable_config_id
+# ---------------------------------------------------------------------------
+
+
+def test_stable_config_id_deterministic():
+ taken1: set[int] = set()
+ taken2: set[int] = set()
+ a = _stable_config_id("openai/gpt-4o", -10_000, taken1)
+ b = _stable_config_id("openai/gpt-4o", -10_000, taken2)
+ assert a == b
+ assert a < 0
+
+
+def test_stable_config_id_collision_decrements():
+ """When two model_ids hash to the same slot, the second should decrement."""
+ taken: set[int] = set()
+ a = _stable_config_id("openai/gpt-4o", -10_000, taken)
+ # Force a collision by pre-populating ``taken`` with a slot we know will be
+ # picked.
+ taken_forced = {a}
+ b = _stable_config_id("openai/gpt-4o", -10_000, taken_forced)
+ assert b != a
+ assert b == a - 1
+ assert b in taken_forced
+
+
+def test_stable_config_id_different_models_different_ids():
+ taken: set[int] = set()
+ ids = {
+ _stable_config_id("openai/gpt-4o", -10_000, taken),
+ _stable_config_id("anthropic/claude-3.5-sonnet", -10_000, taken),
+ _stable_config_id("google/gemini-2.0-flash", -10_000, taken),
+ }
+ assert len(ids) == 3
+
+
+def test_stable_config_id_survives_catalogue_churn():
+ """Removing a model should not shift other models' IDs (the bug we fix)."""
+ taken1: set[int] = set()
+ id_a1 = _stable_config_id("openai/gpt-4o", -10_000, taken1)
+ _ = _stable_config_id("anthropic/claude-3-haiku", -10_000, taken1)
+ id_c1 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken1)
+
+ taken2: set[int] = set()
+ id_a2 = _stable_config_id("openai/gpt-4o", -10_000, taken2)
+ id_c2 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken2)
+
+ assert id_a1 == id_a2
+ assert id_c1 == id_c2
+
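+
+# Editor's sketch of the allocator contract pinned above (an assumed
+# shape, not the real ``_stable_config_id``): hash the model id into a
+# deterministic slot below ``offset``, then walk downward past taken
+# slots, so catalogue churn never shifts an existing model's id.
+def _sketch_stable_id(model_id: str, offset: int, taken: set[int]) -> int:
+    import hashlib
+
+    digest = hashlib.sha256(model_id.encode("utf-8")).digest()
+    slot = offset - 1 - (int.from_bytes(digest[:4], "big") % 10_000)
+    while slot in taken:
+        slot -= 1
+    taken.add(slot)
+    return slot
+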
+
+# ---------------------------------------------------------------------------
+# _generate_configs
+# ---------------------------------------------------------------------------
+
+
+_SETTINGS_BASE: dict = {
+ "api_key": "sk-or-test",
+ "id_offset": -10_000,
+ "rpm": 200,
+ "tpm": 1_000_000,
+ "free_rpm": 20,
+ "free_tpm": 100_000,
+ "anonymous_enabled_paid": False,
+ "anonymous_enabled_free": True,
+ "quota_reserve_tokens": 4000,
+}
+
+
+def test_generate_configs_respects_tier():
+ """Premium OR models opt into the router pool; free OR models stay out.
+
+ Strategy-3 split: premium participates in LiteLLM Router load balancing,
+ free stays excluded because OpenRouter enforces a shared global free-tier
+ bucket that per-deployment router accounting can't represent.
+ """
+ raw = [
+ _minimal_openrouter_model(model_id="openai/gpt-4o"),
+ _minimal_openrouter_model(
+ model_id="meta-llama/llama-3.3-70b-instruct:free",
+ pricing={"prompt": "0", "completion": "0"},
+ ),
+ ]
+ cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
+ by_model = {c["model_name"]: c for c in cfgs}
+
+ paid = by_model["openai/gpt-4o"]
+ assert paid["billing_tier"] == "premium"
+ assert paid["rpm"] == 200
+ assert paid["tpm"] == 1_000_000
+ assert paid["anonymous_enabled"] is False
+ assert paid["router_pool_eligible"] is True
+ assert paid[_OPENROUTER_DYNAMIC_MARKER] is True
+
+ free = by_model["meta-llama/llama-3.3-70b-instruct:free"]
+ assert free["billing_tier"] == "free"
+ assert free["rpm"] == 20
+ assert free["tpm"] == 100_000
+ assert free["anonymous_enabled"] is True
+ assert free["router_pool_eligible"] is False
+
+
+def test_generate_configs_excludes_upstream_openrouter_free_router():
+ """OpenRouter's own ``openrouter/free`` meta-router must never become a card.
+
+ The upstream API returns this as a first-class zero-priced model, so
+ without an explicit blocklist entry it would slip through every other
+ filter (text output, tool calling, 200k context, non-Amazon) and land
+ in the selector as a duplicate of the concrete ``:free`` cards. The
+ exclusion in ``_EXCLUDED_MODEL_IDS`` prevents that.
+ """
+ raw = [
+ _minimal_openrouter_model(model_id="openai/gpt-4o"),
+ _minimal_openrouter_model(
+ model_id="openrouter/free",
+ pricing={"prompt": "0", "completion": "0"},
+ ),
+ ]
+ cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
+ model_names = {c["model_name"] for c in cfgs}
+ assert "openrouter/free" not in model_names
+ assert "openai/gpt-4o" in model_names
+
+
+def test_generate_configs_drops_non_text_and_non_tool_models():
+ raw = [
+ _minimal_openrouter_model(model_id="openai/gpt-4o"),
+ { # image-output model
+ "id": "openai/dall-e",
+ "architecture": {"output_modalities": ["image"]},
+ "supported_parameters": ["tools"],
+ "context_length": 200_000,
+ "pricing": {"prompt": "0.01", "completion": "0.01"},
+ },
+ { # text but no tool calling
+ "id": "openai/completion-only",
+ "architecture": {"output_modalities": ["text"]},
+ "supported_parameters": [],
+ "context_length": 200_000,
+ "pricing": {"prompt": "0.01", "completion": "0.01"},
+ },
+ ]
+ cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
+ model_names = [c["model_name"] for c in cfgs]
+ assert "openai/gpt-4o" in model_names
+ assert "openai/dall-e" not in model_names
+ assert "openai/completion-only" not in model_names
+
+
+# ---------------------------------------------------------------------------
+# _generate_image_gen_configs / _generate_vision_llm_configs
+# ---------------------------------------------------------------------------
+
+
+def test_generate_image_gen_configs_filters_by_image_output():
+ """Only models with ``output_modalities`` containing ``image`` are emitted.
+ Tool-calling and context filters are intentionally NOT applied — image
+ generation has nothing to do with tool calls and context windows.
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_image_gen_configs,
+ )
+
+ raw = [
+ # Pure image-gen model (small context, no tools — should still emit).
+ {
+ "id": "openai/gpt-image-1",
+ "architecture": {"output_modalities": ["image"]},
+ "context_length": 4_000,
+ "pricing": {"prompt": "0", "completion": "0"},
+ },
+ # Multi-modal: text+image output (should still emit).
+ {
+ "id": "google/gemini-2.5-flash-image",
+ "architecture": {"output_modalities": ["text", "image"]},
+ "context_length": 1_000_000,
+ "pricing": {"prompt": "0.000001", "completion": "0.000004"},
+ },
+ # Pure text model — must NOT emit.
+ {
+ "id": "openai/gpt-4o",
+ "architecture": {"output_modalities": ["text"]},
+ "context_length": 128_000,
+ "pricing": {"prompt": "0.000005", "completion": "0.000015"},
+ },
+ ]
+
+ cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
+ model_names = {c["model_name"] for c in cfgs}
+ assert "openai/gpt-image-1" in model_names
+ assert "google/gemini-2.5-flash-image" in model_names
+ assert "openai/gpt-4o" not in model_names
+
+ # Each config must carry ``billing_tier`` for routing in image_generation_routes.
+ for c in cfgs:
+ assert c["billing_tier"] in {"free", "premium"}
+ assert c["provider"] == "OPENROUTER"
+ assert c[_OPENROUTER_DYNAMIC_MARKER] is True
+ # Defense-in-depth: emit the OpenRouter base URL at source so a
+ # downstream call site that forgets ``resolve_api_base`` still
+ # doesn't 404 against an inherited Azure endpoint.
+ assert c["api_base"] == "https://openrouter.ai/api/v1"
+
+
+def test_generate_image_gen_configs_assigns_image_id_offset():
+ """Image configs use a different id_offset (-20000) so their negative
+ IDs don't collide with chat configs (-10000) or vision configs (-30000).
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_image_gen_configs,
+ )
+
+ raw = [
+ {
+ "id": "openai/gpt-image-1",
+ "architecture": {"output_modalities": ["image"]},
+ "context_length": 4_000,
+ "pricing": {"prompt": "0", "completion": "0"},
+ }
+ ]
+ # Don't pass image_id_offset → use the module default (-20000).
+ cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
+ assert all(c["id"] < -20_000 + 1 for c in cfgs)
+ assert all(c["id"] > -29_000_000 for c in cfgs)
+
+
+def test_generate_vision_llm_configs_filters_by_image_input_text_output():
+ """Vision LLMs must accept image input AND emit text — pure image-gen
+ (no text out) and text-only (no image in) models are excluded.
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_vision_llm_configs,
+ )
+
+ raw = [
+ # GPT-4o: vision LLM (image in, text out) — must emit.
+ {
+ "id": "openai/gpt-4o",
+ "architecture": {
+ "input_modalities": ["text", "image"],
+ "output_modalities": ["text"],
+ },
+ "context_length": 128_000,
+ "pricing": {"prompt": "0.000005", "completion": "0.000015"},
+ },
+ # Pure image generator — image *output*, no text out. Must NOT emit.
+ {
+ "id": "openai/gpt-image-1",
+ "architecture": {
+ "input_modalities": ["text"],
+ "output_modalities": ["image"],
+ },
+ "context_length": 4_000,
+ "pricing": {"prompt": "0", "completion": "0"},
+ },
+ # Pure text model (no image in). Must NOT emit.
+ {
+ "id": "anthropic/claude-3-haiku",
+ "architecture": {
+ "input_modalities": ["text"],
+ "output_modalities": ["text"],
+ },
+ "context_length": 200_000,
+ "pricing": {"prompt": "0.000001", "completion": "0.000005"},
+ },
+ ]
+
+ cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
+ names = {c["model_name"] for c in cfgs}
+ assert names == {"openai/gpt-4o"}
+
+ cfg = cfgs[0]
+ assert cfg["billing_tier"] == "premium"
+ # Pricing carried inline so pricing_registration can register vision
+ # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
+ # is cleared.
+ assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
+ assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
+ assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
+ # Defense-in-depth: emit the OpenRouter base URL at source so a
+ # downstream call site that forgets ``resolve_api_base`` still
+ # doesn't inherit an Azure endpoint.
+ assert cfg["api_base"] == "https://openrouter.ai/api/v1"
+
+
+def test_generate_vision_llm_configs_drops_chat_only_filters():
+ """A small-context vision model that doesn't advertise tool calling is
+ still a valid vision LLM for "describe this image" prompts. The chat
+ filters (``supports_tool_calling``, ``has_sufficient_context``) must
+ NOT be applied to vision emission.
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_vision_llm_configs,
+ )
+
+ raw = [
+ {
+ "id": "tiny/vision-mini",
+ "architecture": {
+ "input_modalities": ["text", "image"],
+ "output_modalities": ["text"],
+ },
+ "supported_parameters": [], # no tools
+ "context_length": 4_000, # well below MIN_CONTEXT_LENGTH
+ "pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
+ }
+ ]
+
+ cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
+ assert len(cfgs) == 1
+ assert cfgs[0]["model_name"] == "tiny/vision-mini"
diff --git a/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py
new file mode 100644
index 000000000..4eb1f2295
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py
@@ -0,0 +1,108 @@
+"""Tests for deprecated-key warnings and back-compat in
+``load_openrouter_integration_settings``.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+def _write_yaml(tmp_path: Path, body: str) -> Path:
+ cfg_dir = tmp_path / "app" / "config"
+ cfg_dir.mkdir(parents=True)
+ cfg_path = cfg_dir / "global_llm_config.yaml"
+ cfg_path.write_text(body, encoding="utf-8")
+ return cfg_path
+
+
+def _patch_base_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
+ from app import config as config_module
+
+ monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
+
+
+def test_legacy_billing_tier_emits_warning(monkeypatch, tmp_path, capsys):
+ _write_yaml(
+ tmp_path,
+ """
+openrouter_integration:
+ enabled: true
+ api_key: "sk-or-test"
+ billing_tier: "premium"
+""".lstrip(),
+ )
+ _patch_base_dir(monkeypatch, tmp_path)
+
+ from app.config import load_openrouter_integration_settings
+
+ settings = load_openrouter_integration_settings()
+ captured = capsys.readouterr().out
+ assert settings is not None
+ assert "billing_tier is deprecated" in captured
+
+
+def test_legacy_anonymous_enabled_back_compat(monkeypatch, tmp_path, capsys):
+ _write_yaml(
+ tmp_path,
+ """
+openrouter_integration:
+ enabled: true
+ api_key: "sk-or-test"
+ anonymous_enabled: true
+""".lstrip(),
+ )
+ _patch_base_dir(monkeypatch, tmp_path)
+
+ from app.config import load_openrouter_integration_settings
+
+ settings = load_openrouter_integration_settings()
+ captured = capsys.readouterr().out
+ assert settings is not None
+ assert settings["anonymous_enabled_paid"] is True
+ assert settings["anonymous_enabled_free"] is True
+ assert "anonymous_enabled is" in captured
+ assert "deprecated" in captured
+
+
+def test_new_keys_take_priority_over_legacy_back_compat(monkeypatch, tmp_path, capsys):
+ """If both legacy and new keys are present, new keys win (setdefault)."""
+ _write_yaml(
+ tmp_path,
+ """
+openrouter_integration:
+ enabled: true
+ api_key: "sk-or-test"
+ anonymous_enabled: true
+ anonymous_enabled_paid: false
+ anonymous_enabled_free: false
+""".lstrip(),
+ )
+ _patch_base_dir(monkeypatch, tmp_path)
+
+ from app.config import load_openrouter_integration_settings
+
+ settings = load_openrouter_integration_settings()
+ capsys.readouterr()
+ assert settings is not None
+ assert settings["anonymous_enabled_paid"] is False
+ assert settings["anonymous_enabled_free"] is False
+
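+
+# Editor's sketch of the back-compat shim the test above relies on (an
+# assumed shape; the real code lives in
+# ``load_openrouter_integration_settings``): the legacy key fans out via
+# ``setdefault``, so explicitly-set new keys always win.
+def _sketch_anonymous_backcompat(raw: dict) -> dict:
+    if "anonymous_enabled" in raw:
+        raw.setdefault("anonymous_enabled_paid", raw["anonymous_enabled"])
+        raw.setdefault("anonymous_enabled_free", raw["anonymous_enabled"])
+    return raw
+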
+
+def test_disabled_integration_returns_none(monkeypatch, tmp_path):
+ _write_yaml(
+ tmp_path,
+ """
+openrouter_integration:
+ enabled: false
+ api_key: "sk-or-test"
+""".lstrip(),
+ )
+ _patch_base_dir(monkeypatch, tmp_path)
+
+ from app.config import load_openrouter_integration_settings
+
+ assert load_openrouter_integration_settings() is None
diff --git a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py
new file mode 100644
index 000000000..1c74aa928
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py
@@ -0,0 +1,331 @@
+"""Unit tests for the OpenRouter ``_enrich_health`` background task."""
+
+from __future__ import annotations
+
+from typing import Any
+
+import pytest
+
+from app.services.openrouter_integration_service import (
+ OpenRouterIntegrationService,
+)
+from app.services.quality_score import (
+ _HEALTH_FAIL_RATIO_FALLBACK,
+)
+
+pytestmark = pytest.mark.unit
+
+
+def _or_cfg(
+ *,
+ cid: int,
+ model_name: str,
+ tier: str = "premium",
+ static_score: int = 50,
+) -> dict:
+ return {
+ "id": cid,
+ "provider": "OPENROUTER",
+ "model_name": model_name,
+ "billing_tier": tier,
+ "auto_pin_tier": "B" if tier == "premium" else "C",
+ "quality_score_static": static_score,
+ "quality_score_health": None,
+ "quality_score": static_score,
+ "health_gated": False,
+ }
+
+
+class _StubResponse:
+ def __init__(self, *, payload: dict, status_code: int = 200):
+ self._payload = payload
+ self.status_code = status_code
+
+ def raise_for_status(self) -> None:
+ if self.status_code >= 400:
+ raise RuntimeError(f"HTTP {self.status_code}")
+
+ def json(self) -> dict:
+ return self._payload
+
+
+class _StubAsyncClient:
+ """Minimal drop-in for ``httpx.AsyncClient`` used by ``_fetch_endpoints``."""
+
+ def __init__(self, responder):
+ self._responder = responder
+ self.requests: list[str] = []
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
+ async def get(self, url: str, headers: dict | None = None) -> _StubResponse:
+ self.requests.append(url)
+ return self._responder(url)
+
+
+def _patch_async_client(monkeypatch, responder) -> _StubAsyncClient:
+ """Replace ``httpx.AsyncClient`` for the duration of the test."""
+ client = _StubAsyncClient(responder)
+ monkeypatch.setattr(
+ "app.services.openrouter_integration_service.httpx.AsyncClient",
+ lambda *_args, **_kwargs: client,
+ )
+ return client
+
+
+def _healthy_payload() -> dict:
+ return {
+ "data": {
+ "endpoints": [
+ {
+ "status": 0,
+ "uptime_last_30m": 0.99,
+ "uptime_last_1d": 0.995,
+ "uptime_last_5m": 0.99,
+ }
+ ]
+ }
+ }
+
+
+def _unhealthy_payload() -> dict:
+ return {
+ "data": {
+ "endpoints": [
+ {
+ "status": 0,
+ "uptime_last_30m": 0.55,
+ "uptime_last_1d": 0.62,
+ "uptime_last_5m": 0.50,
+ }
+ ]
+ }
+ }
+
+
+# ---------------------------------------------------------------------------
+# Bounded fan-out + happy path
+# ---------------------------------------------------------------------------
+
+
+async def test_enrich_health_marks_healthy_and_gates_unhealthy(monkeypatch):
+ cfgs = [
+ _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
+ _or_cfg(cid=-2, model_name="venice/dead-model", static_score=60),
+ ]
+
+ def responder(url: str) -> _StubResponse:
+ if "anthropic" in url:
+ return _StubResponse(payload=_healthy_payload())
+ return _StubResponse(payload=_unhealthy_payload())
+
+ _patch_async_client(monkeypatch, responder)
+
+ service = OpenRouterIntegrationService()
+ service._settings = {"api_key": ""}
+ await service._enrich_health(cfgs)
+
+ healthy = next(c for c in cfgs if c["id"] == -1)
+ gated = next(c for c in cfgs if c["id"] == -2)
+
+ assert healthy["health_gated"] is False
+ assert healthy["quality_score_health"] is not None
+ assert healthy["quality_score"] >= healthy["quality_score_static"]
+
+ assert gated["health_gated"] is True
+ assert gated["quality_score"] == gated["quality_score_static"]
+
+
+async def test_enrich_health_only_touches_or_provider(monkeypatch):
+ """YAML cfgs that aren't OPENROUTER must be skipped entirely."""
+ yaml_cfg = {
+ "id": -1,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score_static": 80,
+ "quality_score": 80,
+ "health_gated": False,
+ }
+ or_cfg = _or_cfg(cid=-2, model_name="anthropic/claude-haiku")
+
+ requests: list[str] = []
+
+ def responder(url: str) -> _StubResponse:
+ requests.append(url)
+ return _StubResponse(payload=_healthy_payload())
+
+ _patch_async_client(monkeypatch, responder)
+
+ service = OpenRouterIntegrationService()
+ service._settings = {}
+ await service._enrich_health([yaml_cfg, or_cfg])
+
+ assert all("anthropic/claude-haiku" in r for r in requests)
+ # YAML cfg is untouched.
+ assert yaml_cfg["quality_score"] == 80
+ assert yaml_cfg["health_gated"] is False
+
+
+# ---------------------------------------------------------------------------
+# Failure ratio fallback
+# ---------------------------------------------------------------------------
+
+
+async def test_enrich_health_falls_back_to_last_good_when_failure_ratio_high(
+ monkeypatch,
+):
+ """If >= 25% of fetches fail, keep last-good cache instead of writing
+ partial data."""
+ cfgs = [
+ _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
+ _or_cfg(cid=-2, model_name="openai/gpt-5", static_score=80),
+ _or_cfg(cid=-3, model_name="google/gemini-flash", static_score=65),
+ _or_cfg(cid=-4, model_name="venice/something", static_score=50),
+ ]
+
+ service = OpenRouterIntegrationService()
+ service._settings = {}
+ # Pre-seed last-good cache with a known-healthy snapshot.
+ service._health_cache = {
+ "anthropic/claude-haiku": {"gated": False, "score": 95.0},
+ }
+
+ def all_fail(_url: str) -> _StubResponse:
+ return _StubResponse(payload={}, status_code=500)
+
+ _patch_async_client(monkeypatch, all_fail)
+ await service._enrich_health(cfgs)
+
+ # Above threshold ⇒ degraded; last-good cache wins for the cached cfg.
+ cached_hit = next(c for c in cfgs if c["model_name"] == "anthropic/claude-haiku")
+ assert cached_hit["quality_score_health"] == 95.0
+ assert cached_hit["health_gated"] is False
+ # Confirm the threshold constant we're testing against is real.
+ assert _HEALTH_FAIL_RATIO_FALLBACK <= 1.0
+
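+
+# Editor's sketch of the degrade decision exercised above (an assumed
+# shape; the real threshold is ``_HEALTH_FAIL_RATIO_FALLBACK``): at or
+# above the failure ratio, reuse the last-good cache instead of writing
+# partial health data.
+def _sketch_should_degrade(failed: int, total: int, threshold: float) -> bool:
+    return total > 0 and (failed / total) >= threshold
+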
+
+async def test_enrich_health_keeps_static_only_with_no_cache_and_failures(
+ monkeypatch,
+):
+ """If a fetch fails and there's no last-good cache, the cfg keeps its
+ static-only ``quality_score`` and is *not* gated by default."""
+ cfgs = [
+ _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
+ ]
+
+ def fail(_url: str) -> _StubResponse:
+ return _StubResponse(payload={}, status_code=500)
+
+ _patch_async_client(monkeypatch, fail)
+
+ service = OpenRouterIntegrationService()
+ service._settings = {}
+ await service._enrich_health(cfgs)
+
+ cfg = cfgs[0]
+ assert cfg["health_gated"] is False
+ assert cfg["quality_score"] == cfg["quality_score_static"]
+ assert cfg["quality_score_health"] is None
+
+
+# ---------------------------------------------------------------------------
+# Last-good cache: success populates, next failure reuses
+# ---------------------------------------------------------------------------
+
+
+async def test_enrich_health_populates_cache_on_success_then_reuses_on_failure(
+ monkeypatch,
+):
+ cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70)
+
+ service = OpenRouterIntegrationService()
+ service._settings = {}
+
+ def healthy(_url: str) -> _StubResponse:
+ return _StubResponse(payload=_healthy_payload())
+
+ _patch_async_client(monkeypatch, healthy)
+ await service._enrich_health([cfg])
+
+ assert "anthropic/claude-haiku" in service._health_cache
+ cached_score = service._health_cache["anthropic/claude-haiku"]["score"]
+ assert cached_score is not None
+
+ # Next cycle: enough other healthy cfgs so failure ratio stays below
+ # the 25% threshold even when this one fails individually.
+ other_cfgs = [
+ _or_cfg(cid=-2 - i, model_name=f"healthy/m-{i}", static_score=60)
+ for i in range(10)
+ ]
+ cfg["quality_score_health"] = None
+ cfg["quality_score"] = cfg["quality_score_static"]
+
+ def mixed(url: str) -> _StubResponse:
+ if "anthropic" in url:
+ return _StubResponse(payload={}, status_code=500)
+ return _StubResponse(payload=_healthy_payload())
+
+ _patch_async_client(monkeypatch, mixed)
+ await service._enrich_health([cfg, *other_cfgs])
+
+ assert cfg["quality_score_health"] == cached_score
+ assert cfg["health_gated"] is False
+
+
+# ---------------------------------------------------------------------------
+# Bounded fan-out: respects top-N caps
+# ---------------------------------------------------------------------------
+
+
+async def test_enrich_health_bounds_premium_fanout(monkeypatch):
+ """Top-N premium cap is honoured even when many cfgs are present."""
+ from app.services.quality_score import _HEALTH_ENRICH_TOP_N_PREMIUM
+
+ cfgs = [
+ _or_cfg(
+ cid=-i, model_name=f"openai/m-{i}", tier="premium", static_score=100 - i
+ )
+ for i in range(1, _HEALTH_ENRICH_TOP_N_PREMIUM + 20)
+ ]
+
+ seen: list[str] = []
+
+ def responder(url: str) -> _StubResponse:
+ seen.append(url)
+ return _StubResponse(payload=_healthy_payload())
+
+ _patch_async_client(monkeypatch, responder)
+
+ service = OpenRouterIntegrationService()
+ service._settings = {}
+ await service._enrich_health(cfgs)
+
+ assert len(seen) == _HEALTH_ENRICH_TOP_N_PREMIUM
+
+
+async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
+ """When the catalogue has no OR cfgs at all, no HTTP calls fire."""
+ yaml_cfg: dict[str, Any] = {
+ "id": -1,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5",
+ "billing_tier": "premium",
+ }
+ requests: list[str] = []
+
+ def responder(url: str) -> _StubResponse:
+ requests.append(url)
+ return _StubResponse(payload=_healthy_payload())
+
+ _patch_async_client(monkeypatch, responder)
+
+ service = OpenRouterIntegrationService()
+ service._settings = {}
+ await service._enrich_health([yaml_cfg])
+ assert requests == []
diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py
new file mode 100644
index 000000000..e97250ff2
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py
@@ -0,0 +1,447 @@
+"""Pricing registration unit tests.
+
+The pricing-registration module is what makes ``response_cost`` populate
+correctly for OpenRouter dynamic models and operator-defined Azure
+deployments — both of which LiteLLM doesn't natively know about. The tests
+exercise:
+
+* The alias generators emit every shape that LiteLLM's cost-callback might
+ use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
+ ``provider/base_model``, ``provider/model_name``, plus the special
+ ``azure_openai`` → ``azure`` normalisation).
+* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
+ with the right alias set and pricing values per provider.
+* Configs without a resolvable pair of cost values are skipped — never
+ registered as zero, since that would override pricing LiteLLM might
+ already know natively.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Alias generators
+# ---------------------------------------------------------------------------
+
+
+def test_openrouter_alias_set_includes_prefixed_and_bare():
+ from app.services.pricing_registration import _alias_set_for_openrouter
+
+ aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet")
+ assert aliases == [
+ "openrouter/anthropic/claude-3-5-sonnet",
+ "anthropic/claude-3-5-sonnet",
+ ]
+
+
+def test_openrouter_alias_set_dedupes():
+ """If the model id is already prefixed with ``openrouter/``, the alias
+ set must not contain duplicates that would re-register the same key
+ twice.
+ """
+ from app.services.pricing_registration import _alias_set_for_openrouter
+
+ aliases = _alias_set_for_openrouter("openrouter/foo")
+ # The bare and prefixed variants compute to the same string here, so we
+ # at minimum require uniqueness.
+ assert len(aliases) == len(set(aliases))
+
+
+def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
+ """``azure_openai`` (our YAML provider slug) must register under
+ ``azure/`` so the LiteLLM Router's deployment-resolution path
+ (which uses provider ``azure``) finds the pricing too.
+ """
+ from app.services.pricing_registration import _alias_set_for_yaml
+
+ aliases = _alias_set_for_yaml(
+ provider="AZURE_OPENAI",
+ model_name="gpt-5.4",
+ base_model="gpt-5.4",
+ )
+ assert "gpt-5.4" in aliases
+ assert "azure_openai/gpt-5.4" in aliases
+ assert "azure/gpt-5.4" in aliases
+
+
+def test_yaml_alias_set_distinguishes_model_name_and_base_model():
+ """When ``model_name`` differs from ``base_model`` (operator labelled a
+ deployment), both must appear in the alias set since either may surface
+ in callbacks depending on the call path.
+ """
+ from app.services.pricing_registration import _alias_set_for_yaml
+
+ aliases = _alias_set_for_yaml(
+ provider="OPENAI",
+ model_name="my-deployment-label",
+ base_model="gpt-4o",
+ )
+ assert "gpt-4o" in aliases
+ assert "openai/gpt-4o" in aliases
+ assert "my-deployment-label" in aliases
+ assert "openai/my-deployment-label" in aliases
+
+
+def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
+ from app.services.pricing_registration import _alias_set_for_yaml
+
+ aliases = _alias_set_for_yaml(
+ provider="",
+ model_name="foo",
+ base_model="bar",
+ )
+ assert "bar" in aliases
+ assert "foo" in aliases
+ assert all("/" not in a for a in aliases)
+
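+
+# Editor's sketch of the fan-out the YAML alias tests above pin down (an
+# assumed shape, not the real generator): emit bare and provider-prefixed
+# forms of both names, adding ``azure`` as an extra prefix for the
+# ``azure_openai`` slug.
+def _sketch_yaml_alias_set(
+    provider: str, model_name: str, base_model: str
+) -> list[str]:
+    prefixes = []
+    if provider:
+        prefixes.append(provider.lower())
+        if provider.lower() == "azure_openai":
+            prefixes.append("azure")
+    aliases: list[str] = []
+    for name in dict.fromkeys([base_model, model_name]):  # dedupe, keep order
+        aliases.append(name)
+        aliases.extend(f"{prefix}/{name}" for prefix in prefixes)
+    return aliases
+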
+
+# ---------------------------------------------------------------------------
+# register_pricing_from_global_configs
+# ---------------------------------------------------------------------------
+
+
+class _RegistrationSpy:
+ """Captures the dicts passed to ``litellm.register_model``.
+
+ Many calls may go through; we just record them all and let tests assert
+ against the union.
+ """
+
+ def __init__(self) -> None:
+ self.calls: list[dict[str, Any]] = []
+
+ def __call__(self, payload: dict[str, Any]) -> None:
+ self.calls.append(payload)
+
+ @property
+ def all_keys(self) -> set[str]:
+ keys: set[str] = set()
+ for payload in self.calls:
+ keys.update(payload.keys())
+ return keys
+
+
+def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
+ spy = _RegistrationSpy()
+ monkeypatch.setattr(
+ "app.services.pricing_registration.litellm.register_model",
+ spy,
+ raising=False,
+ )
+ return spy
+
+
+def _patch_openrouter_pricing(
+ monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
+) -> None:
+ """Pretend the OpenRouter integration is initialised with ``mapping``."""
+
+ class _Stub:
+ def get_raw_pricing(self) -> dict[str, dict[str, str]]:
+ return mapping
+
+ class _StubService:
+ @classmethod
+ def is_initialized(cls) -> bool:
+ return True
+
+ @classmethod
+ def get_instance(cls) -> _Stub:
+ return _Stub()
+
+ monkeypatch.setattr(
+ "app.services.openrouter_integration_service.OpenRouterIntegrationService",
+ _StubService,
+ raising=False,
+ )
+
+
+def test_openrouter_models_register_under_aliases(monkeypatch):
+ """An OpenRouter config whose ``model_name`` is in the cached raw
+ pricing map is registered under both ``openrouter/X`` and bare ``X``.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(
+ monkeypatch,
+ {
+ "anthropic/claude-3-5-sonnet": {
+ "prompt": "0.000003",
+ "completion": "0.000015",
+ }
+ },
+ )
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENROUTER",
+ "model_name": "anthropic/claude-3-5-sonnet",
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys
+ assert "anthropic/claude-3-5-sonnet" in spy.all_keys
+ # Costs are float-converted from the raw OpenRouter strings.
+ payload = spy.calls[0]
+ assert payload["openrouter/anthropic/claude-3-5-sonnet"][
+ "input_cost_per_token"
+ ] == pytest.approx(3e-6)
+ assert payload["openrouter/anthropic/claude-3-5-sonnet"][
+ "output_cost_per_token"
+ ] == pytest.approx(15e-6)
+ assert (
+ payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"]
+ == "openrouter"
+ )
+
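+
+# Editor's sketch of the ``litellm.register_model`` payload shape the
+# assertions above imply (assumed, not the real builder): one cost entry,
+# registered under both the prefixed and the bare alias.
+def _sketch_or_payload(model_id: str, raw_pricing: dict) -> dict:
+    entry = {
+        "input_cost_per_token": float(raw_pricing["prompt"]),
+        "output_cost_per_token": float(raw_pricing["completion"]),
+        "litellm_provider": "openrouter",
+        "mode": "chat",
+    }
+    return {f"openrouter/{model_id}": entry, model_id: dict(entry)}
+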
+
+def test_yaml_override_registers_under_alias_set(monkeypatch):
+ """Operator-declared ``input_cost_per_token`` /
+ ``output_cost_per_token`` on a YAML config registers under every
+ alias the YAML alias generator produces — including the ``azure/``
+ normalisation for ``azure_openai`` providers.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(monkeypatch, {})
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5.4",
+ "litellm_params": {
+ "base_model": "gpt-5.4",
+ "input_cost_per_token": 2e-6,
+ "output_cost_per_token": 8e-6,
+ },
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ keys = spy.all_keys
+ assert "gpt-5.4" in keys
+ assert "azure_openai/gpt-5.4" in keys
+ assert "azure/gpt-5.4" in keys
+
+ payload = spy.calls[0]
+ entry = payload["gpt-5.4"]
+ assert entry["input_cost_per_token"] == pytest.approx(2e-6)
+ assert entry["output_cost_per_token"] == pytest.approx(8e-6)
+ assert entry["litellm_provider"] == "azure"
+
+
+def test_no_override_means_no_registration(monkeypatch):
+ """A YAML config that *omits* both pricing fields must NOT be registered
+ — registering as zero would override LiteLLM's native pricing for the
+ ``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
+ bill drop to $0. Fail-safe is "skip and warn", not "register zero".
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(monkeypatch, {})
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENAI",
+ "model_name": "gpt-4o",
+ "litellm_params": {"base_model": "gpt-4o"},
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert spy.calls == []
+
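+
+# Editor's sketch of the fail-safe pinned above (an assumed shape):
+# register only when both costs resolved; never substitute zero.
+def _sketch_should_register(
+    input_cost: float | None, output_cost: float | None
+) -> bool:
+    return input_cost is not None and output_cost is not None
+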
+
+def test_openrouter_skipped_when_pricing_missing(monkeypatch):
+ """If the OpenRouter raw-pricing cache doesn't carry an entry for a
+ configured model (network blip during refresh, model added later, etc.),
+ we skip it rather than registering zero pricing.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(
+ monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
+ )
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENROUTER",
+ "model_name": "anthropic/claude-3-5-sonnet",
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert spy.calls == []
+
+
+def test_register_continues_after_individual_failure(monkeypatch, caplog):
+ """A single bad ``register_model`` call (e.g. raising LiteLLM error)
+ must not abort registration of the remaining configs.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
+ successful_calls: list[dict[str, Any]] = []
+
+ def _maybe_fail(payload: dict[str, Any]) -> None:
+ if any(k in failing_keys for k in payload):
+ raise RuntimeError("boom")
+ successful_calls.append(payload)
+
+ monkeypatch.setattr(
+ "app.services.pricing_registration.litellm.register_model",
+ _maybe_fail,
+ raising=False,
+ )
+ _patch_openrouter_pricing(
+ monkeypatch,
+ {
+ "anthropic/claude-3-5-sonnet": {
+ "prompt": "0.000003",
+ "completion": "0.000015",
+ }
+ },
+ )
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENROUTER",
+ "model_name": "anthropic/claude-3-5-sonnet",
+ },
+ {
+ "id": 2,
+ "provider": "OPENAI",
+ "model_name": "custom-deployment",
+ "litellm_params": {
+ "base_model": "custom-deployment",
+ "input_cost_per_token": 1e-6,
+ "output_cost_per_token": 2e-6,
+ },
+ },
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ # The good config still registered.
+ assert any("custom-deployment" in payload for payload in successful_calls)
+
+
+def test_vision_configs_registered_with_chat_shape(monkeypatch):
+ """``register_pricing_from_global_configs`` walks
+ ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
+ calls (during indexing) bill correctly. Vision configs use the same
+ chat-shape token prices, but image-gen pricing is intentionally NOT
+ registered here (handled via ``response_cost`` in LiteLLM).
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(
+ monkeypatch,
+ {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
+ )
+
+ # No chat configs — only vision. Proves the vision walk is a separate
+ # iteration, not piggy-backed on the chat list.
+ monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_VISION_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-4o",
+ "billing_tier": "premium",
+ "input_cost_per_token": 5e-6,
+ "output_cost_per_token": 15e-6,
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert "openrouter/openai/gpt-4o" in spy.all_keys
+ payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
+ assert payload_value["mode"] == "chat"
+ assert payload_value["litellm_provider"] == "openrouter"
+ assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
+ assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)
+
+
+def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
+ """If the OpenRouter pricing cache misses a vision model (different
+ catalogue surface), the vision walk falls back to inline
+ ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(monkeypatch, {})
+
+ monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_VISION_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-2.5-flash",
+ "billing_tier": "premium",
+ "input_cost_per_token": 1e-6,
+ "output_cost_per_token": 4e-6,
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert "openrouter/google/gemini-2.5-flash" in spy.all_keys
diff --git a/surfsense_backend/tests/unit/services/test_provider_api_base.py b/surfsense_backend/tests/unit/services/test_provider_api_base.py
new file mode 100644
index 000000000..12cd0a3d5
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_provider_api_base.py
@@ -0,0 +1,107 @@
+"""Unit tests for the shared ``api_base`` resolver.
+
+The cascade exists so vision and image-gen call sites can't silently
+inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``)
+when an OpenRouter / Groq / etc. config ships an empty string. See
+``provider_api_base`` module docstring for the original repro
+(OpenRouter image-gen 404-ing against an Azure endpoint).
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from app.services.provider_api_base import (
+ PROVIDER_DEFAULT_API_BASE,
+ PROVIDER_KEY_DEFAULT_API_BASE,
+ resolve_api_base,
+)
+
+pytestmark = pytest.mark.unit
+
+
+def test_config_value_wins_over_defaults():
+ """A non-empty config value is always returned verbatim, even when the
+ provider has a default — the operator gets the last word."""
+ result = resolve_api_base(
+ provider="OPENROUTER",
+ provider_prefix="openrouter",
+ config_api_base="https://my-openrouter-mirror.example.com/v1",
+ )
+ assert result == "https://my-openrouter-mirror.example.com/v1"
+
+
+def test_provider_key_default_when_config_missing():
+ """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own
+ base URL — the provider-key map must take precedence over the prefix
+ map so DeepSeek requests don't go to OpenAI."""
+ result = resolve_api_base(
+ provider="DEEPSEEK",
+ provider_prefix="openai",
+ config_api_base=None,
+ )
+ assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
+
+
+def test_provider_prefix_default_when_no_key_default():
+ result = resolve_api_base(
+ provider="OPENROUTER",
+ provider_prefix="openrouter",
+ config_api_base=None,
+ )
+ assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
+
+
+def test_unknown_provider_returns_none():
+ """When neither map matches we return ``None`` so the caller can let
+ LiteLLM apply its own provider-integration default (Azure deployment
+ URL, custom-provider URL, etc.)."""
+ result = resolve_api_base(
+ provider="SOMETHING_NEW",
+ provider_prefix="something_new",
+ config_api_base=None,
+ )
+ assert result is None
+
+
+def test_empty_string_config_treated_as_missing():
+ """The original bug: OpenRouter dynamic configs ship ``api_base=""``
+ and downstream call sites use ``if cfg.get("api_base"):`` — empty
+ strings are falsy in Python but the cascade has to step in anyway."""
+ result = resolve_api_base(
+ provider="OPENROUTER",
+ provider_prefix="openrouter",
+ config_api_base="",
+ )
+ assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
+
+
+def test_whitespace_only_config_treated_as_missing():
+ """A config value of ``" "`` is a configuration mistake — treat it
+ as missing instead of forwarding whitespace to LiteLLM (which would
+ almost certainly 404)."""
+ result = resolve_api_base(
+ provider="OPENROUTER",
+ provider_prefix="openrouter",
+ config_api_base=" ",
+ )
+ assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
+
+
+def test_provider_case_insensitive():
+ """Some call sites pass the provider lowercase (DB enum value), others
+ uppercase (YAML key). Both must resolve."""
+ upper = resolve_api_base(
+ provider="DEEPSEEK", provider_prefix="openai", config_api_base=None
+ )
+ lower = resolve_api_base(
+ provider="deepseek", provider_prefix="openai", config_api_base=None
+ )
+ assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
+
+
+def test_all_inputs_none_returns_none():
+ assert (
+ resolve_api_base(provider=None, provider_prefix=None, config_api_base=None)
+ is None
+ )
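+
+
+# Editor's sketch of the cascade the tests above pin down (an assumed
+# shape; the real resolver and both maps are imported at the top of this
+# file).
+def _sketch_resolve_api_base(
+    provider: str | None,
+    provider_prefix: str | None,
+    config_api_base: str | None,
+) -> str | None:
+    if config_api_base and config_api_base.strip():
+        return config_api_base  # operator value wins, returned verbatim
+    if provider and provider.upper() in PROVIDER_KEY_DEFAULT_API_BASE:
+        return PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]
+    if provider_prefix and provider_prefix.lower() in PROVIDER_DEFAULT_API_BASE:
+        return PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]
+    return None  # let LiteLLM apply its own provider-integration default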
diff --git a/surfsense_backend/tests/unit/services/test_provider_capabilities.py b/surfsense_backend/tests/unit/services/test_provider_capabilities.py
new file mode 100644
index 000000000..aac88977f
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_provider_capabilities.py
@@ -0,0 +1,244 @@
+"""Unit tests for the shared chat-image capability resolver.
+
+Two resolvers, two intents:
+
+- ``derive_supports_image_input`` — best-effort True for the catalog and
+ selector. Default-allow on unknown / unmapped models. The streaming
+ task safety net never sees this value directly.
+
+- ``is_known_text_only_chat_model`` — strict opt-out for the safety net.
+ Returns True only when LiteLLM's model map *explicitly* sets
+ ``supports_vision=False``. Anything else (missing key, exception,
+ True) returns False so the request flows through to the provider.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from app.services.provider_capabilities import (
+ derive_supports_image_input,
+ is_known_text_only_chat_model,
+)
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# derive_supports_image_input — OpenRouter modalities path (authoritative)
+# ---------------------------------------------------------------------------
+
+
+def test_or_modalities_with_image_returns_true():
+ assert (
+ derive_supports_image_input(
+ provider="OPENROUTER",
+ model_name="openai/gpt-4o",
+ openrouter_input_modalities=["text", "image"],
+ )
+ is True
+ )
+
+
+def test_or_modalities_text_only_returns_false():
+ assert (
+ derive_supports_image_input(
+ provider="OPENROUTER",
+ model_name="deepseek/deepseek-v3.2-exp",
+ openrouter_input_modalities=["text"],
+ )
+ is False
+ )
+
+
+def test_or_modalities_empty_list_returns_false():
+ """OR explicitly publishing an empty modality list is a definitive
+ 'no inputs at all' signal — treat as False rather than falling back
+ to LiteLLM."""
+ assert (
+ derive_supports_image_input(
+ provider="OPENROUTER",
+ model_name="weird/empty-modalities",
+ openrouter_input_modalities=[],
+ )
+ is False
+ )
+
+
+def test_or_modalities_none_falls_through_to_litellm():
+ """``None`` (missing key) is *not* a definitive signal — fall through
+ to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map."""
+ assert (
+ derive_supports_image_input(
+ provider="OPENAI",
+ model_name="gpt-4o",
+ openrouter_input_modalities=None,
+ )
+ is True
+ )
+
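+
+# Editor's sketch of the authoritative OpenRouter branch pinned above (an
+# assumed shape): a published modality list is definitive either way,
+# while ``None`` means "no signal" and falls through to the LiteLLM map.
+def _sketch_modalities_gate(input_modalities: list[str] | None) -> bool | None:
+    if input_modalities is None:
+        return None  # no signal: caller falls through to LiteLLM
+    return "image" in input_modalities
+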
+
+# ---------------------------------------------------------------------------
+# derive_supports_image_input — LiteLLM model-map path
+# ---------------------------------------------------------------------------
+
+
+def test_litellm_known_vision_model_returns_true():
+ assert (
+ derive_supports_image_input(
+ provider="OPENAI",
+ model_name="gpt-4o",
+ )
+ is True
+ )
+
+
+def test_litellm_base_model_wins_over_model_name():
+ """Azure-style entries pass model_name=deployment_id and put the
+ canonical sku in litellm_params.base_model. The resolver must
+ consult base_model first or the deployment id (which LiteLLM
+ doesn't know) would shadow the real capability."""
+ assert (
+ derive_supports_image_input(
+ provider="AZURE_OPENAI",
+ model_name="my-azure-deployment-id",
+ base_model="gpt-4o",
+ )
+ is True
+ )
+
+
+def test_litellm_unknown_model_default_allows():
+ """Default-allow on unknown — the safety net is the actual block."""
+ assert (
+ derive_supports_image_input(
+ provider="CUSTOM",
+ model_name="brand-new-model-x9-unmapped",
+ custom_provider="brand_new_proxy",
+ )
+ is True
+ )
+
+
+def test_litellm_known_text_only_is_version_tolerant():
+    """``deepseek-chat`` (the DeepSeek-V3 chat sku) is text-only in
+    current LiteLLM maps, but map contents drift across versions, so
+    this test binds only the invariant that the resolver returns a
+    bool and never raises on a known-text-only provider/model."""
+    result = derive_supports_image_input(
+        provider="DEEPSEEK",
+        model_name="deepseek-chat",
+    )
+    # Either False (LiteLLM said an explicit no) or True (default-allow
+    # when the entry isn't mapped on this LiteLLM version) is fine; the
+    # behaviour-binding assertion lives in
+    # ``test_is_known_text_only_returns_true_on_explicit_false`` below.
+    assert isinstance(result, bool)
+
+
+# ---------------------------------------------------------------------------
+# is_known_text_only_chat_model — strict opt-out semantics
+# ---------------------------------------------------------------------------
+
+
+def test_is_known_text_only_returns_false_for_vision_model():
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="gpt-4o",
+ )
+ is False
+ )
+
+
+def test_is_known_text_only_returns_false_for_unknown_model():
+ """Strict opt-out: missing from the map ≠ text-only. The safety net
+ must NOT fire for an unmapped model — that's the regression we're
+ fixing."""
+ assert (
+ is_known_text_only_chat_model(
+ provider="CUSTOM",
+ model_name="brand-new-model-x9-unmapped",
+ custom_provider="brand_new_proxy",
+ )
+ is False
+ )
+
+
+def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
+ """LiteLLM's ``get_model_info`` raises freely on parse errors. The
+ helper swallows the exception and returns False so the safety net
+ doesn't fire on a transient lookup failure."""
+ import app.services.provider_capabilities as pc
+
+ def _raise(**_kwargs):
+ raise ValueError("intentional test failure")
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _raise)
+
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="gpt-4o",
+ )
+ is False
+ )
+
+
+def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
+ """Stub LiteLLM's ``get_model_info`` to return an explicit False so
+ we exercise the opt-out path deterministically. Using a stub keeps
+ the test stable across LiteLLM map updates."""
+ import app.services.provider_capabilities as pc
+
+ def _info(**_kwargs):
+ return {"supports_vision": False, "max_input_tokens": 8192}
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info)
+
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="any-model",
+ )
+ is True
+ )
+
+
+def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
+ import app.services.provider_capabilities as pc
+
+ def _info(**_kwargs):
+ return {"supports_vision": True}
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info)
+
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="any-model",
+ )
+ is False
+ )
+
+
+def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
+ """A model entry without ``supports_vision`` at all is treated as
+ 'unknown' — strict opt-out means False."""
+ import app.services.provider_capabilities as pc
+
+ def _info(**_kwargs):
+ return {"max_input_tokens": 8192} # no supports_vision
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info)
+
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="any-model",
+ )
+ is False
+ )
diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py
new file mode 100644
index 000000000..6fbc8fd62
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_quality_score.py
@@ -0,0 +1,345 @@
+"""Unit tests for the Auto (Fastest) quality scoring module."""
+
+from __future__ import annotations
+
+import time
+
+import pytest
+
+from app.services.quality_score import (
+ _HEALTH_GATE_UPTIME_PCT,
+ _OPERATOR_TRUST_BONUS,
+ aggregate_health,
+ capabilities_signal,
+ context_signal,
+ created_recency_signal,
+ pricing_band,
+ slug_penalty,
+ static_score_or,
+ static_score_yaml,
+)
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# created_recency_signal
+# ---------------------------------------------------------------------------
+
+
+def test_created_recency_signal_recent_model_scores_high():
+ now = 1_750_000_000 # ~mid-2025
+ one_month_ago = now - (30 * 86_400)
+ assert created_recency_signal(one_month_ago, now) == 20
+
+
+def test_created_recency_signal_old_model_scores_zero():
+ now = 1_750_000_000
+ five_years_ago = now - (5 * 365 * 86_400)
+ assert created_recency_signal(five_years_ago, now) == 0
+
+
+def test_created_recency_signal_missing_timestamp_is_neutral():
+ now = 1_750_000_000
+ assert created_recency_signal(None, now) == 0
+ assert created_recency_signal(0, now) == 0
+
+
+def test_created_recency_signal_monotonic_decay():
+ now = 1_750_000_000
+ scores = [
+ created_recency_signal(now - days * 86_400, now)
+ for days in (30, 120, 300, 500, 700, 1000, 1500)
+ ]
+ assert scores == sorted(scores, reverse=True)
+
+
+# ---------------------------------------------------------------------------
+# pricing_band
+# ---------------------------------------------------------------------------
+
+
+def test_pricing_band_free_returns_zero():
+ assert pricing_band("0", "0") == 0
+ assert pricing_band(0.0, 0.0) == 0
+ assert pricing_band(None, None) == 0
+
+
+def test_pricing_band_handles_unparseable():
+ assert pricing_band("not-a-number", "0") == 0
+ assert pricing_band({}, []) == 0 # type: ignore[arg-type]
+
+
+def test_pricing_band_premium_tiers_increase_with_price():
+ cheap = pricing_band("0.0000003", "0.0000005")
+ mid = pricing_band("0.000003", "0.000015")
+ flagship = pricing_band("0.00001", "0.00005")
+ assert 0 < cheap < mid < flagship
+
+
+# ---------------------------------------------------------------------------
+# context_signal
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize(
+ "ctx,expected",
+ [
+ (1_500_000, 10),
+ (1_000_000, 10),
+ (500_000, 8),
+ (200_000, 6),
+ (128_000, 4),
+ (100_000, 2),
+ (50_000, 0),
+ (0, 0),
+ (None, 0),
+ ],
+)
+def test_context_signal_bands(ctx, expected):
+ assert context_signal(ctx) == expected
+
+
+# ---------------------------------------------------------------------------
+# capabilities_signal
+# ---------------------------------------------------------------------------
+
+
+def test_capabilities_signal_caps_at_five():
+ assert (
+ capabilities_signal(
+ ["tools", "structured_outputs", "reasoning", "include_reasoning"]
+ )
+ <= 5
+ )
+
+
+def test_capabilities_signal_tools_only():
+ assert capabilities_signal(["tools"]) == 2
+
+
+def test_capabilities_signal_empty():
+ assert capabilities_signal(None) == 0
+ assert capabilities_signal([]) == 0
+
+
+# ---------------------------------------------------------------------------
+# slug_penalty
+# ---------------------------------------------------------------------------
+
+
+def test_slug_penalty_demotes_tiny_models():
+ assert slug_penalty("meta-llama/llama-3.2-1b-instruct") < 0
+ assert slug_penalty("liquid/lfm-7b") < 0
+ assert slug_penalty("google/gemma-3n-e4b-it") < 0
+
+
+def test_slug_penalty_skips_capable_mini_nano_lite_models():
+ """Critical Option C+ regression: don't penalise modern frontier
+ models named ``-nano`` / ``-mini`` / ``-lite`` (gpt-5-mini, etc.)."""
+ assert slug_penalty("openai/gpt-5-mini") == 0
+ assert slug_penalty("openai/gpt-5-nano") == 0
+ assert slug_penalty("google/gemini-2.5-flash-lite") == 0
+ assert slug_penalty("anthropic/claude-haiku-4.5") == 0
+
+
+def test_slug_penalty_demotes_legacy_variants():
+ assert slug_penalty("openai/o1-preview") < 0
+ assert slug_penalty("foo/bar-base") < 0
+ assert slug_penalty("foo/bar-distill") < 0
+
+
+def test_slug_penalty_empty_input():
+ assert slug_penalty("") == 0
+
+
+# ---------------------------------------------------------------------------
+# static_score_or
+# ---------------------------------------------------------------------------
+
+
+def _or_model(
+ *,
+ model_id: str,
+ created: int | None = None,
+ prompt: str = "0.000003",
+ completion: str = "0.000015",
+ context: int = 200_000,
+ params: list[str] | None = None,
+) -> dict:
+ return {
+ "id": model_id,
+ "created": created,
+ "pricing": {"prompt": prompt, "completion": completion},
+ "context_length": context,
+ "supported_parameters": params if params is not None else ["tools"],
+ }
+
+
+def test_static_score_or_frontier_premium_beats_free_tiny():
+ now = 1_750_000_000
+ frontier = _or_model(
+ model_id="openai/gpt-5",
+ created=now - (60 * 86_400),
+ prompt="0.000005",
+ completion="0.000020",
+ context=400_000,
+ params=["tools", "structured_outputs", "reasoning"],
+ )
+ tiny_free = _or_model(
+ model_id="meta-llama/llama-3.2-1b-instruct:free",
+ created=now - (5 * 365 * 86_400),
+ prompt="0",
+ completion="0",
+ context=128_000,
+ params=["tools"],
+ )
+ assert static_score_or(frontier, now_ts=now) > static_score_or(
+ tiny_free, now_ts=now
+ )
+
+
+def test_static_score_or_score_is_clamped_0_to_100():
+ now = int(time.time())
+ score = static_score_or(_or_model(model_id="openai/gpt-4o"), now_ts=now)
+ assert 0 <= score <= 100
+
+
+def test_static_score_or_unknown_provider_is_neutral_not_zero():
+ now = int(time.time())
+ score = static_score_or(
+ _or_model(model_id="some-new-lab/some-model"),
+ now_ts=now,
+ )
+ assert score > 0
+
+
+def test_static_score_or_recent_release_beats_year_old_same_provider():
+ now = 1_750_000_000
+ fresh = _or_model(model_id="openai/gpt-5", created=now - (60 * 86_400))
+ old = _or_model(model_id="openai/gpt-4-turbo", created=now - (700 * 86_400))
+ assert static_score_or(fresh, now_ts=now) > static_score_or(old, now_ts=now)
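+
+
+# Illustrative only: one way ``static_score_or`` could combine the signals
+# tested above. The neutral base, recency decay, and weights are all
+# assumptions; the tests only constrain relative ordering and the 0-100 clamp.
+def _sketch_static_score_or(model: dict, *, now_ts: int) -> float:
+    score = 50.0  # neutral base keeps unknown providers above zero
+    score += context_signal(model.get("context_length"))
+    score += capabilities_signal(model.get("supported_parameters"))
+    score += slug_penalty(model.get("id", ""))
+    created = model.get("created")
+    if created:
+        age_days = (now_ts - created) / 86_400
+        score += max(0.0, 10.0 - age_days / 73)  # assumed recency bonus
+    return max(0.0, min(100.0, score))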
+
+
+# ---------------------------------------------------------------------------
+# static_score_yaml
+# ---------------------------------------------------------------------------
+
+
+def test_static_score_yaml_includes_operator_bonus():
+ cfg = {
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5",
+ "litellm_params": {"base_model": "azure/gpt-5"},
+ }
+ score = static_score_yaml(cfg)
+ assert score >= _OPERATOR_TRUST_BONUS
+
+
+def test_static_score_yaml_unknown_provider_still_carries_bonus():
+ cfg = {
+ "provider": "SOME_NEW_PROVIDER",
+ "model_name": "weird-model",
+ }
+ score = static_score_yaml(cfg)
+ assert score >= _OPERATOR_TRUST_BONUS
+
+
+def test_static_score_yaml_clamped_0_to_100():
+ cfg = {
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5",
+ "litellm_params": {"base_model": "azure/gpt-5"},
+ }
+ assert 0 <= static_score_yaml(cfg) <= 100
+
+
+# ---------------------------------------------------------------------------
+# aggregate_health
+# ---------------------------------------------------------------------------
+
+
+def test_aggregate_health_gates_when_uptime_below_threshold():
+ """Live data showed Venice-routed cfgs at 53-68%; this guards that the
+ 90% gate excludes them."""
+ venice_endpoints = [
+ {
+ "status": 0,
+ "uptime_last_30m": 0.55,
+ "uptime_last_1d": 0.60,
+ "uptime_last_5m": 0.50,
+ },
+ {
+ "status": 0,
+ "uptime_last_30m": 0.65,
+ "uptime_last_1d": 0.68,
+ "uptime_last_5m": 0.62,
+ },
+ ]
+ gated, score = aggregate_health(venice_endpoints)
+ assert gated is True
+ assert score is None
+
+
+def test_aggregate_health_passes_for_healthy_provider():
+ healthy = [
+ {
+ "status": 0,
+ "uptime_last_30m": 0.99,
+ "uptime_last_1d": 0.995,
+ "uptime_last_5m": 0.99,
+ },
+ ]
+ gated, score = aggregate_health(healthy)
+ assert gated is False
+ assert score is not None
+ assert score >= _HEALTH_GATE_UPTIME_PCT
+
+
+def test_aggregate_health_picks_best_endpoint_across_multiple():
+ """Multi-endpoint aggregation should reward the best non-null uptime."""
+ mixed = [
+ {"status": 0, "uptime_last_30m": 0.55},
+ {"status": 0, "uptime_last_30m": 0.97}, # this one passes the gate
+ ]
+ gated, score = aggregate_health(mixed)
+ assert gated is False
+ assert score is not None
+
+
+def test_aggregate_health_empty_endpoints_gated():
+ gated, score = aggregate_health([])
+ assert gated is True
+ assert score is None
+
+
+def test_aggregate_health_no_status_zero_gated():
+ """Even with high uptime, no OK status means the cfg is broken upstream."""
+ endpoints = [
+ {"status": 1, "uptime_last_30m": 0.99},
+ {"status": 2, "uptime_last_30m": 0.98},
+ ]
+ gated, score = aggregate_health(endpoints)
+ assert gated is True
+ assert score is None
+
+
+def test_aggregate_health_all_uptime_null_gated():
+ endpoints = [
+ {"status": 0, "uptime_last_30m": None, "uptime_last_1d": None},
+ ]
+ gated, score = aggregate_health(endpoints)
+ assert gated is True
+ assert score is None
+
+
+def test_aggregate_health_pct_normalisation():
+ """OpenRouter returns 0-1 fractions; some endpoints surface 0-100%
+ percentages. Both should reach the same gate decision."""
+ fraction_form = [{"status": 0, "uptime_last_30m": 0.95}]
+ pct_form = [{"status": 0, "uptime_last_30m": 95.0}]
+ g1, s1 = aggregate_health(fraction_form)
+ g2, s2 = aggregate_health(pct_form)
+ assert g1 is False and g2 is False
+ assert s1 is not None and s2 is not None
+ assert abs(s1 - s2) < 0.5
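+
+
+# Illustrative only: the normalise-then-gate shape the tests above pin.
+# Preferring the best uptime across windows and endpoints is an assumption;
+# the 90% constant matches ``_HEALTH_GATE_UPTIME_PCT``.
+def _sketch_aggregate_health(endpoints: list[dict]) -> tuple[bool, float | None]:
+    best: float | None = None
+    for ep in endpoints:
+        if ep.get("status") != 0:
+            continue  # broken upstream, never counts
+        for key in ("uptime_last_5m", "uptime_last_30m", "uptime_last_1d"):
+            uptime = ep.get(key)
+            if uptime is None:
+                continue
+            pct = uptime * 100 if uptime <= 1 else uptime  # 0-1 vs 0-100 form
+            best = pct if best is None else max(best, pct)
+    if best is None or best < _HEALTH_GATE_UPTIME_PCT:
+        return True, None
+    return False, best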
diff --git a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py
new file mode 100644
index 000000000..9e35b6f9c
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py
@@ -0,0 +1,157 @@
+"""Unit tests for ``QuotaCheckedVisionLLM``.
+
+Validates that:
+
+* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
+ enforcement) and forwards the inner LLM's response on success.
+* The wrapper proxies non-overridden attributes to the inner LLM
+ (``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
+ still work without quota gating (they're not used in indexing today).
+* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
+ bubbles it up — the ETL pipeline catches that and falls back to OCR.
+"""
+
+from __future__ import annotations
+
+import contextlib
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+class _FakeInnerLLM:
+ """Stand-in for ``langchain_litellm.ChatLiteLLM``."""
+
+ def __init__(self, response: Any = "OCR'd content") -> None:
+ self._response = response
+ self.ainvoke_calls: list[Any] = []
+
+ async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
+ self.ainvoke_calls.append(input)
+ return self._response
+
+ def some_other_method(self, x: int) -> int:
+ return x * 2
+
+
+@contextlib.asynccontextmanager
+async def _passthrough_billable_call(**_kwargs):
+ """Stand-in for billable_call that always allows the call to run."""
+
+ class _Acc:
+ def __init__(self) -> None:
+ self.total_cost_micros = 0
+ self.total_prompt_tokens = 0
+ self.total_completion_tokens = 0
+ self.grand_total = 0
+ # Instance attribute rather than a class-level list so repeated
+ # context entries can't share (and cross-contaminate) state.
+ self.calls: list[Any] = []
+
+ def per_message_summary(self) -> dict[str, dict[str, int]]:
+ return {}
+
+ yield _Acc()
+
+
+@pytest.mark.asyncio
+async def test_ainvoke_routes_through_billable_call(monkeypatch):
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
+
+ captured_kwargs: list[dict[str, Any]] = []
+
+ @contextlib.asynccontextmanager
+ async def _spy_billable_call(**kwargs):
+ captured_kwargs.append(kwargs)
+ async with _passthrough_billable_call() as acc:
+ yield acc
+
+ monkeypatch.setattr(
+ "app.services.quota_checked_vision_llm.billable_call",
+ _spy_billable_call,
+ raising=False,
+ )
+
+ inner = _FakeInnerLLM(response="A red apple on a white table")
+ user_id = uuid4()
+ wrapper = QuotaCheckedVisionLLM(
+ inner,
+ user_id=user_id,
+ search_space_id=99,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ )
+
+ result = await wrapper.ainvoke([{"text": "what is this?"}])
+ assert result == "A red apple on a white table"
+ assert len(inner.ainvoke_calls) == 1
+ assert len(captured_kwargs) == 1
+ bc_kwargs = captured_kwargs[0]
+ assert bc_kwargs["user_id"] == user_id
+ assert bc_kwargs["search_space_id"] == 99
+ assert bc_kwargs["billing_tier"] == "premium"
+ assert bc_kwargs["base_model"] == "openai/gpt-4o"
+ assert bc_kwargs["quota_reserve_tokens"] == 4000
+ assert bc_kwargs["usage_type"] == "vision_extraction"
+
+
+@pytest.mark.asyncio
+async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
+ from app.services.billable_calls import QuotaInsufficientError
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
+
+ @contextlib.asynccontextmanager
+ async def _denying_billable_call(**_kwargs):
+ raise QuotaInsufficientError(
+ usage_type="vision_extraction",
+ used_micros=5_000_000,
+ limit_micros=5_000_000,
+ remaining_micros=0,
+ )
+ yield # unreachable but required for asynccontextmanager type
+
+ monkeypatch.setattr(
+ "app.services.quota_checked_vision_llm.billable_call",
+ _denying_billable_call,
+ raising=False,
+ )
+
+ inner = _FakeInnerLLM()
+ wrapper = QuotaCheckedVisionLLM(
+ inner,
+ user_id=uuid4(),
+ search_space_id=1,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ )
+
+ with pytest.raises(QuotaInsufficientError):
+ await wrapper.ainvoke([{"text": "x"}])
+
+ # Inner LLM never ran on a denied reservation.
+ assert inner.ainvoke_calls == []
+
+
+@pytest.mark.asyncio
+async def test_proxies_non_overridden_attributes_to_inner():
+ """``__getattr__`` forwards anything not on the proxy itself, so any
+ method we didn't explicitly override (``invoke``, ``astream``,
+ ``with_structured_output``, etc.) still works — just without quota
+ gating, which is fine because the indexer only ever calls ainvoke.
+ """
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
+
+ inner = _FakeInnerLLM()
+ wrapper = QuotaCheckedVisionLLM(
+ inner,
+ user_id=uuid4(),
+ search_space_id=1,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ )
+
+ # ``some_other_method`` is on the inner only.
+ assert wrapper.some_other_method(7) == 14
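+
+
+# Illustrative only: the wrapper shape these tests assume. The context
+# manager is injected so the sketch stays self-contained; the real
+# ``QuotaCheckedVisionLLM`` imports ``billable_call`` itself.
+class _SketchQuotaProxy:
+    def __init__(self, inner: Any, billable_call_cm: Any, **billing_kwargs: Any):
+        self._inner = inner
+        self._billable_call_cm = billable_call_cm
+        self._billing_kwargs = billing_kwargs
+
+    async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
+        # The reservation happens before the inner call; a denial raises
+        # and the inner LLM never runs.
+        async with self._billable_call_cm(
+            usage_type="vision_extraction", **self._billing_kwargs
+        ):
+            return await self._inner.ainvoke(input, *args, **kwargs)
+
+    def __getattr__(self, name: str) -> Any:
+        # Only consulted for attributes missing on the proxy itself.
+        return getattr(self._inner, name)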
diff --git a/surfsense_backend/tests/unit/services/test_supports_image_input.py b/surfsense_backend/tests/unit/services/test_supports_image_input.py
new file mode 100644
index 000000000..71fdee1c7
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_supports_image_input.py
@@ -0,0 +1,281 @@
+"""Unit tests for the chat-catalog ``supports_image_input`` capability flag.
+
+Capability is sourced from two places, in order of preference:
+
+1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs
+ (authoritative — OpenRouter publishes per-model modalities directly).
+2. LiteLLM's authoritative model map (``litellm.supports_vision``) for
+ YAML / BYOK configs that don't carry an explicit operator override.
+
+The catalog default is *True* (conservative-allow): an unknown / unmapped
+model is not pre-judged. The streaming-task safety net
+(``is_known_text_only_chat_model``) is the only place a False actually
+blocks a request — and it requires LiteLLM to *explicitly* mark the model
+as text-only.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from app.services.openrouter_integration_service import (
+ _OPENROUTER_DYNAMIC_MARKER,
+ _generate_configs,
+ _supports_image_input,
+)
+
+pytestmark = pytest.mark.unit
+
+
+_SETTINGS_BASE: dict = {
+ "api_key": "sk-or-test",
+ "id_offset": -10_000,
+ "rpm": 200,
+ "tpm": 1_000_000,
+ "free_rpm": 20,
+ "free_tpm": 100_000,
+ "anonymous_enabled_paid": False,
+ "anonymous_enabled_free": True,
+ "quota_reserve_tokens": 4000,
+}
+
+
+# ---------------------------------------------------------------------------
+# _supports_image_input helper (OpenRouter modalities)
+# ---------------------------------------------------------------------------
+
+
+def test_supports_image_input_true_for_multimodal():
+ assert (
+ _supports_image_input(
+ {
+ "id": "openai/gpt-4o",
+ "architecture": {
+ "input_modalities": ["text", "image"],
+ "output_modalities": ["text"],
+ },
+ }
+ )
+ is True
+ )
+
+
+def test_supports_image_input_false_for_text_only():
+ """The exact failure mode the safety net guards against — DeepSeek V3
+ is a text-in/text-out model and would 404 if forwarded image_url."""
+ assert (
+ _supports_image_input(
+ {
+ "id": "deepseek/deepseek-v3.2-exp",
+ "architecture": {
+ "input_modalities": ["text"],
+ "output_modalities": ["text"],
+ },
+ }
+ )
+ is False
+ )
+
+
+def test_supports_image_input_false_when_modalities_missing():
+ """Defensive: missing architecture is treated as text-only at the
+ OpenRouter helper level. The wider catalog resolver
+ (`derive_supports_image_input`) only consults modalities when they
+ are non-empty, otherwise it falls back to LiteLLM."""
+ assert _supports_image_input({"id": "weird/model"}) is False
+ assert _supports_image_input({"id": "weird/model", "architecture": {}}) is False
+ assert (
+ _supports_image_input(
+ {"id": "weird/model", "architecture": {"input_modalities": None}}
+ )
+ is False
+ )
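+
+
+# Illustrative only: the three tests above are consistent with a strict
+# membership check on OpenRouter's published modalities.
+def _sketch_supports_image_input(model: dict) -> bool:
+    modalities = (model.get("architecture") or {}).get("input_modalities") or []
+    return "image" in modalities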
+
+
+# ---------------------------------------------------------------------------
+# _generate_configs threads the flag onto every emitted chat config
+# ---------------------------------------------------------------------------
+
+
+def test_generate_configs_emits_supports_image_input():
+ raw = [
+ {
+ "id": "openai/gpt-4o",
+ "architecture": {
+ "input_modalities": ["text", "image"],
+ "output_modalities": ["text"],
+ },
+ "supported_parameters": ["tools"],
+ "context_length": 200_000,
+ "pricing": {"prompt": "0.000005", "completion": "0.000015"},
+ },
+ {
+ "id": "deepseek/deepseek-v3.2-exp",
+ "architecture": {
+ "input_modalities": ["text"],
+ "output_modalities": ["text"],
+ },
+ "supported_parameters": ["tools"],
+ "context_length": 200_000,
+ "pricing": {"prompt": "0.000003", "completion": "0.000015"},
+ },
+ ]
+ cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
+ by_model = {c["model_name"]: c for c in cfgs}
+
+ gpt = by_model["openai/gpt-4o"]
+ assert gpt["supports_image_input"] is True
+ assert gpt[_OPENROUTER_DYNAMIC_MARKER] is True
+
+ deepseek = by_model["deepseek/deepseek-v3.2-exp"]
+ assert deepseek["supports_image_input"] is False
+ assert deepseek[_OPENROUTER_DYNAMIC_MARKER] is True
+
+
+# ---------------------------------------------------------------------------
+# YAML loader: defer to derive_supports_image_input on unannotated entries
+# ---------------------------------------------------------------------------
+
+
+def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch):
+ """The regression case: an Azure GPT-5.x YAML entry without a
+ ``supports_image_input`` override should resolve to True via LiteLLM's
+ model map (which says ``supports_vision: true``). Previously this
+ defaulted to False, blocking every image turn for vision-capable
+ YAML configs."""
+ yaml_dir = tmp_path / "app" / "config"
+ yaml_dir.mkdir(parents=True)
+ (yaml_dir / "global_llm_config.yaml").write_text(
+ """
+global_llm_configs:
+ - id: -2
+ name: Azure GPT-4o
+ provider: AZURE_OPENAI
+ model_name: gpt-4o
+ api_key: sk-test
+""",
+ encoding="utf-8",
+ )
+
+ from app import config as config_module
+
+ monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
+
+ configs = config_module.load_global_llm_configs()
+ assert len(configs) == 1
+ assert configs[0]["supports_image_input"] is True
+
+
+def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch):
+ yaml_dir = tmp_path / "app" / "config"
+ yaml_dir.mkdir(parents=True)
+ (yaml_dir / "global_llm_config.yaml").write_text(
+ """
+global_llm_configs:
+ - id: -1
+ name: GPT-4o
+ provider: OPENAI
+ model_name: gpt-4o
+ api_key: sk-test
+ supports_image_input: false
+""",
+ encoding="utf-8",
+ )
+
+ from app import config as config_module
+
+ monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
+
+ configs = config_module.load_global_llm_configs()
+ assert len(configs) == 1
+ # Operator override always wins, even against LiteLLM's True.
+ assert configs[0]["supports_image_input"] is False
+
+
+def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch):
+ """Unknown / unmapped model in YAML: default-allow. The streaming
+ safety net (which requires an explicit-False from LiteLLM) is the
+ only place a real block happens, so we don't lock the user out of
+ a freshly added third-party entry the catalog can't introspect."""
+ yaml_dir = tmp_path / "app" / "config"
+ yaml_dir.mkdir(parents=True)
+ (yaml_dir / "global_llm_config.yaml").write_text(
+ """
+global_llm_configs:
+ - id: -1
+ name: Some Brand New Model
+ provider: CUSTOM
+ custom_provider: brand_new_proxy
+ model_name: brand-new-model-x9
+ api_key: sk-test
+""",
+ encoding="utf-8",
+ )
+
+ from app import config as config_module
+
+ monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
+
+ configs = config_module.load_global_llm_configs()
+ assert len(configs) == 1
+ assert configs[0]["supports_image_input"] is True
+
+
+# ---------------------------------------------------------------------------
+# AgentConfig threads the flag through both YAML and Auto / BYOK
+# ---------------------------------------------------------------------------
+
+
+def test_agent_config_from_yaml_explicit_overrides_resolver():
+ from app.agents.new_chat.llm_config import AgentConfig
+
+ cfg_text_only = AgentConfig.from_yaml_config(
+ {
+ "id": -1,
+ "name": "Text Only Override",
+ "provider": "openai",
+ "model_name": "gpt-4o", # Capable per LiteLLM, but operator says no.
+ "api_key": "sk-test",
+ "supports_image_input": False,
+ }
+ )
+ cfg_explicit_vision = AgentConfig.from_yaml_config(
+ {
+ "id": -2,
+ "name": "GPT-4o",
+ "provider": "openai",
+ "model_name": "gpt-4o",
+ "api_key": "sk-test",
+ "supports_image_input": True,
+ }
+ )
+ assert cfg_text_only.supports_image_input is False
+ assert cfg_explicit_vision.supports_image_input is True
+
+
+def test_agent_config_from_yaml_unannotated_uses_resolver():
+ """Without an explicit YAML key, AgentConfig defers to the catalog
+ resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True."""
+ from app.agents.new_chat.llm_config import AgentConfig
+
+ cfg = AgentConfig.from_yaml_config(
+ {
+ "id": -1,
+ "name": "GPT-4o (no override)",
+ "provider": "openai",
+ "model_name": "gpt-4o",
+ "api_key": "sk-test",
+ }
+ )
+ assert cfg.supports_image_input is True
+
+
+def test_agent_config_auto_mode_supports_image_input():
+ """Auto routes across the pool. We optimistically allow image input
+ so users can keep their selection on Auto with a vision-capable
+ deployment somewhere in the pool. The router's own `allowed_fails`
+ handles non-vision deployments via fallback."""
+ from app.agents.new_chat.llm_config import AgentConfig
+
+ auto = AgentConfig.from_auto_mode()
+ assert auto.supports_image_input is True
diff --git a/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py
new file mode 100644
index 000000000..63681828d
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py
@@ -0,0 +1,515 @@
+"""Cost-based premium quota unit tests.
+
+Covers the USD-micro behaviour added in migration 140:
+
+* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
+ calls in a turn — used as the debit amount when ``agent_config.is_premium``
+ is true, regardless of which underlying model produced each call. This
+ preserves the prior "premium turn → all calls in turn count" rule from the
+ token-based system.
+* ``estimate_call_reserve_micros`` scales linearly with model pricing,
+ clamps to a sane floor when pricing is unknown, and respects the
+ ``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
+ can't lock the whole balance on one call.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# TurnTokenAccumulator — premium-turn debit semantics
+# ---------------------------------------------------------------------------
+
+
+def test_total_cost_micros_sums_premium_and_free_calls():
+ """A premium turn that also called a free sub-agent debits the union.
+
+ The plan deliberately preserved the existing "premium turn → all calls
+ count" behaviour because per-call premium filtering relied on
+ ``LLMRouterService._premium_model_strings`` which only covers router-pool
+ deployments. ``total_cost_micros`` therefore must include free-model
+ calls (whose ``cost_micros`` is typically ``0``) as well as the premium
+ call's actual provider cost.
+ """
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ # Premium model (e.g. claude-opus): non-zero cost.
+ acc.add(
+ model="anthropic/claude-3-5-sonnet",
+ prompt_tokens=1200,
+ completion_tokens=400,
+ total_tokens=1600,
+ cost_micros=12_345,
+ )
+ # Free sub-agent (e.g. title-gen on a free model): zero cost.
+ acc.add(
+ model="gpt-4o-mini",
+ prompt_tokens=120,
+ completion_tokens=20,
+ total_tokens=140,
+ cost_micros=0,
+ )
+ # A second premium-priced call within the same turn.
+ acc.add(
+ model="anthropic/claude-3-5-sonnet",
+ prompt_tokens=800,
+ completion_tokens=200,
+ total_tokens=1000,
+ cost_micros=7_500,
+ )
+
+ assert acc.total_cost_micros == 12_345 + 0 + 7_500
+ # Token totals stay correct so the FE display path still works.
+ assert acc.grand_total == 1600 + 140 + 1000
+
+
+def test_total_cost_micros_zero_when_no_calls():
+ """An empty accumulator must report zero cost (no division-by-zero, no None)."""
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ assert acc.total_cost_micros == 0
+ assert acc.grand_total == 0
+
+
+def test_per_message_summary_groups_cost_by_model():
+ """``per_message_summary`` must accumulate ``cost_micros`` per model so the
+ SSE ``model_breakdown`` payload reports actual USD spend per provider.
+ """
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ acc.add(
+ model="claude-3-5-sonnet",
+ prompt_tokens=100,
+ completion_tokens=50,
+ total_tokens=150,
+ cost_micros=4_000,
+ )
+ acc.add(
+ model="claude-3-5-sonnet",
+ prompt_tokens=200,
+ completion_tokens=100,
+ total_tokens=300,
+ cost_micros=8_000,
+ )
+ acc.add(
+ model="gpt-4o-mini",
+ prompt_tokens=50,
+ completion_tokens=10,
+ total_tokens=60,
+ cost_micros=200,
+ )
+
+ summary = acc.per_message_summary()
+ assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000
+ assert summary["claude-3-5-sonnet"]["total_tokens"] == 450
+ assert summary["gpt-4o-mini"]["cost_micros"] == 200
+
+
+def test_serialized_calls_includes_cost_micros():
+ """``serialized_calls`` is what flows into the SSE ``call_details``
+ payload; cost_micros must be present on each entry so the FE message-info
+ dropdown can render per-call USD.
+ """
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ acc.add(
+ model="m",
+ prompt_tokens=1,
+ completion_tokens=1,
+ total_tokens=2,
+ cost_micros=42,
+ )
+ serialized = acc.serialized_calls()
+ assert serialized == [
+ {
+ "model": "m",
+ "prompt_tokens": 1,
+ "completion_tokens": 1,
+ "total_tokens": 2,
+ "cost_micros": 42,
+ "call_kind": "chat",
+ }
+ ]
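+
+
+# Illustrative only: a minimal accumulator with the summation properties
+# the tests above rely on; the real ``TurnTokenAccumulator`` carries more.
+from dataclasses import dataclass, field
+
+
+@dataclass
+class _SketchCall:
+    model: str
+    total_tokens: int
+    cost_micros: int
+
+
+@dataclass
+class _SketchAccumulator:
+    calls: list[_SketchCall] = field(default_factory=list)
+
+    @property
+    def total_cost_micros(self) -> int:
+        return sum(c.cost_micros for c in self.calls)
+
+    @property
+    def grand_total(self) -> int:
+        return sum(c.total_tokens for c in self.calls)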
+
+
+# ---------------------------------------------------------------------------
+# estimate_call_reserve_micros — sizing and clamping
+# ---------------------------------------------------------------------------
+
+
+def test_reserve_returns_floor_when_model_unknown(monkeypatch):
+ """If LiteLLM doesn't know the model, ``get_model_info`` raises and the
+ helper falls back to the 100-micro floor — small enough that a user with
+ $0.0001 left can still send a tiny request, but non-zero so we still gate
+ against an empty balance.
+ """
+ import litellm
+
+ from app.services import token_quota_service
+
+ def _raise(_name):
+ raise KeyError("unknown")
+
+ monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False)
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="nonexistent-model",
+ quota_reserve_tokens=4000,
+ )
+ assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
+ assert micros == 100
+
+
+def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
+ """LiteLLM may *return* a model with both cost-per-token fields at 0
+ (pricing not yet registered). The helper must not multiply 0 x tokens
+ and end up reserving 0 — it must clamp to the floor.
+ """
+ import litellm
+
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0},
+ raising=False,
+ )
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="some-pending-model",
+ quota_reserve_tokens=4000,
+ )
+ assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
+
+
+def test_reserve_scales_with_model_cost(monkeypatch):
+ """Claude-Opus-priced model with 4000 reserve_tokens reserves
+ ~$0.36 = 360_000 micros. Critically this must NOT be clamped down to
+ some small artificial cap — that was the bug the plan called out.
+ """
+ import litellm
+
+ from app.config import config
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {
+ "input_cost_per_token": 15e-6,
+ "output_cost_per_token": 75e-6,
+ },
+ raising=False,
+ )
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="claude-3-opus",
+ quota_reserve_tokens=4000,
+ )
+ # 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros.
+ assert micros == 360_000
+
+
+def test_reserve_clamps_to_max_ceiling(monkeypatch):
+ """A misconfigured "$1000 / M" model with 4000 reserve_tokens would
+ nominally compute to $4 = 4_000_000 micros. The ceiling
+ ``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry
+ can't lock the user's whole balance on one call.
+ """
+ import litellm
+
+ from app.config import config
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {
+ "input_cost_per_token": 1e-3,
+ "output_cost_per_token": 0,
+ },
+ raising=False,
+ )
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="oops-misconfigured",
+ quota_reserve_tokens=4000,
+ )
+ assert micros == 1_000_000
+
+
+def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
+ """Per-config ``quota_reserve_tokens`` is optional; when ``None`` or
+ zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL``
+ so anonymous-style configs still reserve the operator-tunable default.
+ """
+ import litellm
+
+ from app.config import config
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {
+ "input_cost_per_token": 1e-6,
+ "output_cost_per_token": 1e-6,
+ },
+ raising=False,
+ )
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
+
+ # 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros
+ assert (
+ token_quota_service.estimate_call_reserve_micros(
+ base_model="cheap", quota_reserve_tokens=None
+ )
+ == 4000
+ )
+ assert (
+ token_quota_service.estimate_call_reserve_micros(
+ base_model="cheap", quota_reserve_tokens=0
+ )
+ == 4000
+ )
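+
+
+# Illustrative only: the floor/ceiling clamp shape the tests above pin.
+# ``config`` and ``litellm`` are injected to keep the sketch self-contained;
+# rounding and error handling in the real helper are assumptions.
+def _sketch_estimate_reserve_micros(
+    base_model: str, quota_reserve_tokens: int | None, *, config, litellm
+) -> int:
+    tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
+    try:
+        info = litellm.get_model_info(base_model)
+        per_token_usd = (info.get("input_cost_per_token") or 0) + (
+            info.get("output_cost_per_token") or 0
+        )
+    except Exception:
+        per_token_usd = 0  # unknown model: clamp to the floor below
+    micros = int(tokens * per_token_usd * 1_000_000)
+    micros = max(micros, 100)  # _QUOTA_MIN_RESERVE_MICROS floor
+    return min(micros, config.QUOTA_MAX_RESERVE_MICROS)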
+
+
+# ---------------------------------------------------------------------------
+# TokenTrackingCallback — image vs chat usage shape
+# ---------------------------------------------------------------------------
+
+
+class _FakeImageUsage:
+ """Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""
+
+ def __init__(
+ self,
+ input_tokens: int = 0,
+ output_tokens: int = 0,
+ total_tokens: int | None = None,
+ ) -> None:
+ self.input_tokens = input_tokens
+ self.output_tokens = output_tokens
+ if total_tokens is not None:
+ self.total_tokens = total_tokens
+
+
+class _FakeImageResponse:
+ """Mimics LiteLLM's ``ImageResponse`` — same name so the callback's
+ ``type(...).__name__`` probe routes to the image branch.
+ """
+
+ def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
+ self.usage = usage
+ if response_cost is not None:
+ self._hidden_params = {"response_cost": response_cost}
+
+
+# Re-tag the helper class as ``ImageResponse`` so the callback's
+# ``type(...).__name__`` probe routes it to the image branch. The explicit
+# ``__name__`` assignment keeps the module's ``_Fake*`` naming convention
+# while making the probe dependency obvious at the definition site.
+_FakeImageResponse.__name__ = "ImageResponse"
+
+
+class _FakeChatUsage:
+ def __init__(self, prompt: int, completion: int):
+ self.prompt_tokens = prompt
+ self.completion_tokens = completion
+ self.total_tokens = prompt + completion
+
+
+class _FakeChatResponse:
+ def __init__(self, usage: _FakeChatUsage):
+ self.usage = usage
+
+
+@pytest.mark.asyncio
+async def test_callback_reads_image_usage_input_output_tokens():
+ """``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
+ for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
+ prompt_tokens/completion_tokens which is the chat shape.
+ """
+ from app.services.token_tracking_service import (
+ TokenTrackingCallback,
+ scoped_turn,
+ )
+
+ cb = TokenTrackingCallback()
+ response = _FakeImageResponse(
+ usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
+ response_cost=0.04, # $0.04 per image
+ )
+
+ async with scoped_turn() as acc:
+ await cb.async_log_success_event(
+ kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
+ response_obj=response,
+ start_time=None,
+ end_time=None,
+ )
+ assert len(acc.calls) == 1
+ call = acc.calls[0]
+ assert call.prompt_tokens == 42
+ assert call.completion_tokens == 8
+ assert call.total_tokens == 50
+ # 0.04 USD = 40_000 micros
+ assert call.cost_micros == 40_000
+ assert call.call_kind == "image_generation"
+
+
+@pytest.mark.asyncio
+async def test_callback_chat_path_unchanged():
+ """Chat responses must still read prompt_tokens/completion_tokens."""
+ from app.services.token_tracking_service import (
+ TokenTrackingCallback,
+ scoped_turn,
+ )
+
+ cb = TokenTrackingCallback()
+ response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))
+
+ async with scoped_turn() as acc:
+ await cb.async_log_success_event(
+ kwargs={
+ "model": "openrouter/anthropic/claude-3-5-sonnet",
+ "response_cost": 0.0036,
+ },
+ response_obj=response,
+ start_time=None,
+ end_time=None,
+ )
+ assert len(acc.calls) == 1
+ call = acc.calls[0]
+ assert call.prompt_tokens == 120
+ assert call.completion_tokens == 30
+ assert call.total_tokens == 150
+ assert call.cost_micros == 3_600
+ assert call.call_kind == "chat"
+
+
+@pytest.mark.asyncio
+async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
+ """When OpenRouter omits ``usage.cost`` LiteLLM's
+ ``default_image_cost_calculator`` raises. The defensive image branch in
+ ``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is
+ chat-shaped and would raise too) — it returns 0 with a WARNING log.
+ """
+ import litellm
+
+ from app.services.token_tracking_service import (
+ TokenTrackingCallback,
+ scoped_turn,
+ )
+
+ # Force completion_cost to raise the same way OpenRouter image-gen fails.
+ def _boom(*_args, **_kwargs):
+ raise ValueError("model_cost: missing entry for openrouter image model")
+
+ monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False)
+
+ # And make sure cost_per_token is NEVER called for the image path —
+ # if it were, our ``is_image=True`` branch is broken.
+ cost_per_token_calls: list = []
+
+ def _record_cost_per_token(**kwargs):
+ cost_per_token_calls.append(kwargs)
+ return (0.0, 0.0)
+
+ monkeypatch.setattr(
+ litellm, "cost_per_token", _record_cost_per_token, raising=False
+ )
+
+ cb = TokenTrackingCallback()
+ response = _FakeImageResponse(
+ usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
+ )
+
+ async with scoped_turn() as acc:
+ await cb.async_log_success_event(
+ kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
+ response_obj=response,
+ start_time=None,
+ end_time=None,
+ )
+
+ assert len(acc.calls) == 1
+ assert acc.calls[0].cost_micros == 0
+ assert acc.calls[0].call_kind == "image_generation"
+ # The image branch must short-circuit before cost_per_token.
+ assert cost_per_token_calls == []
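+
+
+# Illustrative only: the defensive branch shape the test above pins. The
+# explicit-cost short circuit and the exact LiteLLM calls are assumptions;
+# the invariant is that image responses never reach ``cost_per_token``.
+def _sketch_extract_cost_usd(kwargs, response_obj, *, is_image, litellm, logger):
+    explicit = kwargs.get("response_cost")
+    if explicit is not None:
+        return explicit
+    try:
+        return litellm.completion_cost(completion_response=response_obj)
+    except Exception:
+        if is_image:
+            logger.warning("image cost unavailable; recording $0")
+            return 0.0
+        prompt_usd, completion_usd = litellm.cost_per_token(
+            model=kwargs.get("model"),
+            prompt_tokens=getattr(response_obj.usage, "prompt_tokens", 0),
+            completion_tokens=getattr(response_obj.usage, "completion_tokens", 0),
+        )
+        return prompt_usd + completion_usd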
+
+
+# ---------------------------------------------------------------------------
+# scoped_turn — ContextVar reset semantics (issue B)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_scoped_turn_restores_outer_accumulator():
+ """``scoped_turn`` must restore the previous ContextVar value on exit
+ so a per-call wrapper inside an outer chat turn doesn't leak its
+ accumulator outward (which would cause double-debit at chat-turn exit).
+ """
+ from app.services.token_tracking_service import (
+ get_current_accumulator,
+ scoped_turn,
+ start_turn,
+ )
+
+ outer = start_turn()
+ assert get_current_accumulator() is outer
+
+ async with scoped_turn() as inner:
+ assert get_current_accumulator() is inner
+ assert inner is not outer
+ inner.add(
+ model="x",
+ prompt_tokens=1,
+ completion_tokens=1,
+ total_tokens=2,
+ cost_micros=5,
+ )
+
+ # After exit the outer accumulator is restored unchanged.
+ assert get_current_accumulator() is outer
+ assert outer.total_cost_micros == 0
+ assert len(outer.calls) == 0
+ # The inner accumulator captured the call but didn't bleed into outer.
+ assert inner.total_cost_micros == 5
+
+
+@pytest.mark.asyncio
+async def test_scoped_turn_resets_to_none_when_no_outer():
+ """Running ``scoped_turn`` outside any chat turn (e.g. a background
+ indexing job) must leave the ContextVar at ``None`` on exit so the
+ next *unrelated* request starts clean.
+ """
+ from app.services.token_tracking_service import (
+ _turn_accumulator,
+ get_current_accumulator,
+ scoped_turn,
+ )
+
+ # ContextVar default is None for a fresh test isolated context. We
+ # simulate "no outer" explicitly to be robust against test order.
+ token = _turn_accumulator.set(None)
+ try:
+ assert get_current_accumulator() is None
+ async with scoped_turn() as acc:
+ assert get_current_accumulator() is acc
+ assert get_current_accumulator() is None
+ finally:
+ _turn_accumulator.reset(token)
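+
+
+# Illustrative only: the set/reset discipline both tests above pin, assuming
+# the ContextVar and accumulator names imported in the tests.
+import contextlib
+
+
+@contextlib.asynccontextmanager
+async def _sketch_scoped_turn():
+    from app.services.token_tracking_service import (
+        TurnTokenAccumulator,
+        _turn_accumulator,
+    )
+
+    acc = TurnTokenAccumulator()
+    token = _turn_accumulator.set(acc)
+    try:
+        yield acc
+    finally:
+        # reset() restores whatever was current at set() time: the outer
+        # accumulator inside a chat turn, or None outside one.
+        _turn_accumulator.reset(token)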
diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py
new file mode 100644
index 000000000..b8ba9d80c
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py
@@ -0,0 +1,89 @@
+"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
+defaults from ``litellm.api_base`` either.
+
+Vision shares the same shape as image-gen — global YAML / OpenRouter
+dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm``
+call sites would silently drop the empty string and inherit
+``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
+construction so we test the kwargs we hand to it instead.
+"""
+
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.asyncio
+async def test_get_vision_llm_global_openrouter_sets_api_base():
+ """Global negative-ID branch: an OpenRouter vision config with
+ ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with
+ ``api_base="https://openrouter.ai/api/v1"`` — never an empty string,
+ never silently absent."""
+ from app.services import llm_service
+
+ cfg = {
+ "id": -30_001,
+ "name": "GPT-4o Vision (OpenRouter)",
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-4o",
+ "api_key": "sk-or-test",
+ "api_base": "",
+ "api_version": None,
+ "litellm_params": {},
+ "billing_tier": "free",
+ }
+
+ search_space = MagicMock()
+ search_space.id = 1
+ search_space.user_id = "user-x"
+ search_space.vision_llm_config_id = cfg["id"]
+
+ session = AsyncMock()
+ scalars = MagicMock()
+ scalars.first.return_value = search_space
+ result = MagicMock()
+ result.scalars.return_value = scalars
+ session.execute.return_value = result
+
+ captured: dict = {}
+
+ class FakeSanitized:
+ def __init__(self, **kwargs):
+ captured.update(kwargs)
+
+ with (
+ patch(
+ "app.services.vision_llm_router_service.get_global_vision_llm_config",
+ return_value=cfg,
+ ),
+ patch(
+ "app.agents.new_chat.llm_config.SanitizedChatLiteLLM",
+ new=FakeSanitized,
+ ),
+ ):
+ await llm_service.get_vision_llm(session=session, search_space_id=1)
+
+ assert captured.get("api_base") == "https://openrouter.ai/api/v1"
+ assert captured["model"] == "openrouter/openai/gpt-4o"
+
+
+def test_vision_router_deployment_sets_api_base_when_config_empty():
+ """Auto-mode vision router: deployments are fed to ``litellm.Router``,
+ so the resolver has to apply at deployment construction time too."""
+ from app.services.vision_llm_router_service import VisionLLMRouterService
+
+ deployment = VisionLLMRouterService._config_to_deployment(
+ {
+ "model_name": "openai/gpt-4o",
+ "provider": "OPENROUTER",
+ "api_key": "sk-or-test",
+ "api_base": "",
+ }
+ )
+ assert deployment is not None
+ assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
+ assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o"
diff --git a/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py b/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py
new file mode 100644
index 000000000..c317eba20
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py
@@ -0,0 +1,526 @@
+"""Unit tests for ``AssistantContentBuilder``.
+
+Pins the in-memory ``ContentPart[]`` projection so the JSONB the server
+persists matches what the frontend renders live (see
+``surfsense_web/lib/chat/streaming-state.ts``). The tests assert the
+structural shape of ``snapshot()`` and, via ``_assert_jsonb_safe``, that it is
+``json.dumps``-safe (the streaming finally block writes it directly to
+``new_chat_messages.content`` without an explicit serialization round
+trip).
+"""
+
+from __future__ import annotations
+
+import json
+
+import pytest
+
+from app.tasks.chat.content_builder import AssistantContentBuilder
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _assert_jsonb_safe(parts: list[dict]) -> None:
+ """Sanity check: any snapshot must round-trip through ``json.dumps``."""
+ serialized = json.dumps(parts)
+ assert json.loads(serialized) == parts
+
+
+# ---------------------------------------------------------------------------
+# Text turns
+# ---------------------------------------------------------------------------
+
+
+class TestTextOnly:
+ def test_single_text_block_collapses_consecutive_deltas(self):
+ b = AssistantContentBuilder()
+ b.on_text_start("text-1")
+ b.on_text_delta("text-1", "Hello")
+ b.on_text_delta("text-1", " ")
+ b.on_text_delta("text-1", "world")
+ b.on_text_end("text-1")
+
+ snap = b.snapshot()
+ assert snap == [{"type": "text", "text": "Hello world"}]
+ assert not b.is_empty()
+ _assert_jsonb_safe(snap)
+
+ def test_empty_text_start_end_pair_leaves_no_part(self):
+ # Mirrors the FE: a text-start without any deltas should
+ # not materialise an empty ``{"type":"text","text":""}`` part.
+ b = AssistantContentBuilder()
+ b.on_text_start("text-1")
+ b.on_text_end("text-1")
+
+ assert b.snapshot() == []
+ assert b.is_empty()
+
+ def test_text_after_text_end_starts_fresh_part(self):
+ b = AssistantContentBuilder()
+ b.on_text_start("text-1")
+ b.on_text_delta("text-1", "first")
+ b.on_text_end("text-1")
+
+ b.on_text_start("text-2")
+ b.on_text_delta("text-2", "second")
+ b.on_text_end("text-2")
+
+ snap = b.snapshot()
+ assert snap == [
+ {"type": "text", "text": "first"},
+ {"type": "text", "text": "second"},
+ ]
+
+
+class TestReasoningThenText:
+ def test_reasoning_followed_by_text_yields_two_parts_in_order(self):
+ b = AssistantContentBuilder()
+ b.on_reasoning_start("r-1")
+ b.on_reasoning_delta("r-1", "Considering options...")
+ b.on_reasoning_end("r-1")
+
+ b.on_text_start("text-1")
+ b.on_text_delta("text-1", "The answer is 42.")
+ b.on_text_end("text-1")
+
+ snap = b.snapshot()
+ assert snap == [
+ {"type": "reasoning", "text": "Considering options..."},
+ {"type": "text", "text": "The answer is 42."},
+ ]
+ _assert_jsonb_safe(snap)
+
+ def test_text_delta_after_reasoning_implicitly_closes_reasoning(self):
+ # Mirrors FE ``appendText``: a text delta arriving while a
+ # reasoning part is "active" still produces a fresh text
+ # part, never appends into the reasoning block.
+ b = AssistantContentBuilder()
+ b.on_reasoning_start("r-1")
+ b.on_reasoning_delta("r-1", "thinking")
+ # No explicit reasoning_end — text delta should close it.
+ b.on_text_delta("text-1", "answer")
+
+ snap = b.snapshot()
+ assert snap == [
+ {"type": "reasoning", "text": "thinking"},
+ {"type": "text", "text": "answer"},
+ ]
+
+
+# ---------------------------------------------------------------------------
+# Tool calls
+# ---------------------------------------------------------------------------
+
+
+class TestToolHeavyTurn:
+ def test_full_tool_lifecycle_produces_complete_tool_call_part(self):
+ b = AssistantContentBuilder()
+ # Some narration before the tool fires.
+ b.on_text_start("text-1")
+ b.on_text_delta("text-1", "Searching...")
+ b.on_text_end("text-1")
+
+ b.on_tool_input_start(
+ ui_id="call_run123",
+ tool_name="web_search",
+ langchain_tool_call_id="lc_tool_abc",
+ )
+ b.on_tool_input_delta("call_run123", '{"query":')
+ b.on_tool_input_delta("call_run123", '"surfsense"}')
+ b.on_tool_input_available(
+ ui_id="call_run123",
+ tool_name="web_search",
+ args={"query": "surfsense"},
+ langchain_tool_call_id="lc_tool_abc",
+ )
+ b.on_tool_output_available(
+ ui_id="call_run123",
+ output={"status": "completed", "citations": {}},
+ langchain_tool_call_id="lc_tool_abc",
+ )
+
+ snap = b.snapshot()
+ assert snap[0] == {"type": "text", "text": "Searching..."}
+ tool_part = snap[1]
+ assert tool_part["type"] == "tool-call"
+ assert tool_part["toolCallId"] == "call_run123"
+ assert tool_part["toolName"] == "web_search"
+ assert tool_part["args"] == {"query": "surfsense"}
+ # ``argsText`` is the pretty-printed final JSON, not the raw
+ # streaming buffer (FE ``stream-pipeline.ts:128``).
+ assert tool_part["argsText"] == json.dumps(
+ {"query": "surfsense"}, indent=2, ensure_ascii=False
+ )
+ assert tool_part["langchainToolCallId"] == "lc_tool_abc"
+ assert tool_part["result"] == {"status": "completed", "citations": {}}
+ _assert_jsonb_safe(snap)
+
+ def test_tool_input_available_without_prior_start_creates_card(self):
+ # Legacy / parity_v2-OFF path: tool-input-available may be
+ # emitted without a prior tool-input-start (no streamed
+ # tool_call_chunks). The card should still be created.
+ b = AssistantContentBuilder()
+ b.on_tool_input_available(
+ ui_id="call_run42",
+ tool_name="grep",
+ args={"pattern": "TODO"},
+ langchain_tool_call_id="lc_x",
+ )
+ b.on_tool_output_available(
+ ui_id="call_run42",
+ output={"matches": 3},
+ langchain_tool_call_id="lc_x",
+ )
+
+ snap = b.snapshot()
+ assert len(snap) == 1
+ part = snap[0]
+ assert part["type"] == "tool-call"
+ assert part["toolCallId"] == "call_run42"
+ assert part["args"] == {"pattern": "TODO"}
+ assert part["langchainToolCallId"] == "lc_x"
+ assert part["result"] == {"matches": 3}
+
+ def test_tool_input_start_idempotent_for_same_ui_id(self):
+ # parity_v2: tool-input-start can fire from BOTH the chunk
+ # registration path AND the canonical ``on_tool_start`` path.
+ # The second call must not create a duplicate part.
+ b = AssistantContentBuilder()
+ b.on_tool_input_start("call_x", "ls", "lc_x")
+ b.on_tool_input_start("call_x", "ls", "lc_x")
+ snap = b.snapshot()
+ assert len(snap) == 1
+
+ def test_tool_input_delta_without_prior_start_is_silently_dropped(self):
+ b = AssistantContentBuilder()
+ b.on_tool_input_delta("call_unknown", '{"orphan": "delta"}')
+ assert b.snapshot() == []
+
+ def test_langchain_tool_call_id_backfills_only_when_absent(self):
+ b = AssistantContentBuilder()
+ b.on_tool_input_start("call_x", "ls", "lc_first")
+ # Late event must NOT clobber an already-set lc id.
+ b.on_tool_input_start("call_x", "ls", "lc_late")
+ snap = b.snapshot()
+ assert snap[0]["langchainToolCallId"] == "lc_first"
+
+ def test_args_text_streaming_buffer_reflects_concatenation(self):
+ b = AssistantContentBuilder()
+ b.on_tool_input_start("call_x", "save_doc", "lc_y")
+ b.on_tool_input_delta("call_x", '{"title":')
+ b.on_tool_input_delta("call_x", '"Hi"}')
+ # Snapshot mid-stream should see the partial buffer (the FE
+ # tolerates invalid JSON and renders it as-is).
+ mid = b.snapshot()
+ assert mid[0]["argsText"] == '{"title":"Hi"}'
+ # Then tool-input-available replaces with pretty-printed.
+ b.on_tool_input_available(
+ "call_x",
+ "save_doc",
+ {"title": "Hi"},
+ "lc_y",
+ )
+ final = b.snapshot()
+ assert final[0]["argsText"] == json.dumps(
+ {"title": "Hi"}, indent=2, ensure_ascii=False
+ )
+
+
+# ---------------------------------------------------------------------------
+# Thinking steps & separators
+# ---------------------------------------------------------------------------
+
+
+class TestThinkingSteps:
+ def test_first_thinking_step_unshifts_singleton_to_index_zero(self):
+ b = AssistantContentBuilder()
+ b.on_text_start("text-1")
+ b.on_text_delta("text-1", "Hello")
+ b.on_text_end("text-1")
+
+ b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item-a"])
+
+ snap = b.snapshot()
+ # Singleton goes to index 0 (FE ``updateThinkingSteps`` unshift).
+ assert snap[0]["type"] == "data-thinking-steps"
+ assert snap[0]["data"]["steps"] == [
+ {
+ "id": "step-1",
+ "title": "Analyzing",
+ "status": "in_progress",
+ "items": ["item-a"],
+ }
+ ]
+ assert snap[1] == {"type": "text", "text": "Hello"}
+
+ def test_subsequent_thinking_steps_mutate_the_singleton_in_place(self):
+ b = AssistantContentBuilder()
+ b.on_thinking_step("step-1", "Analyzing", "in_progress", [])
+ b.on_thinking_step("step-2", "Searching", "in_progress", ["q"])
+ b.on_thinking_step("step-1", "Analyzing", "completed", ["done"])
+
+ snap = b.snapshot()
+ assert len([p for p in snap if p["type"] == "data-thinking-steps"]) == 1
+ steps = snap[0]["data"]["steps"]
+ assert len(steps) == 2
+ assert steps[0]["id"] == "step-1"
+ assert steps[0]["status"] == "completed"
+ assert steps[0]["items"] == ["done"]
+ assert steps[1]["id"] == "step-2"
+
+ def test_thinking_step_with_text_continues_appending_to_text(self):
+ b = AssistantContentBuilder()
+ b.on_text_start("text-1")
+ b.on_text_delta("text-1", "first")
+
+ # Thinking step inserts at index 0, bumps text idx from 0 to 1.
+ b.on_thinking_step("step-1", "Working", "in_progress", [])
+ b.on_text_delta("text-1", " second")
+
+ snap = b.snapshot()
+ text_parts = [p for p in snap if p["type"] == "text"]
+ assert text_parts == [{"type": "text", "text": "first second"}]
+
+ def test_thinking_step_without_id_is_dropped(self):
+ b = AssistantContentBuilder()
+ b.on_thinking_step("", "noop", "in_progress", None)
+ assert b.snapshot() == []
+ assert b.is_empty()
+
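+
+# Illustrative only: the singleton discipline TestThinkingSteps pins. One
+# ``data-thinking-steps`` part is created at index 0 on first use and
+# mutated in place afterwards, mirroring the FE unshift.
+def _sketch_upsert_thinking_step(parts: list[dict], step: dict) -> None:
+    singleton = next((p for p in parts if p["type"] == "data-thinking-steps"), None)
+    if singleton is None:
+        singleton = {"type": "data-thinking-steps", "data": {"steps": []}}
+        parts.insert(0, singleton)
+    steps = singleton["data"]["steps"]
+    existing = next((s for s in steps if s["id"] == step["id"]), None)
+    if existing is None:
+        steps.append(step)
+    else:
+        existing.update(step)
+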
+
+class TestStepSeparators:
+ def test_separator_no_op_before_any_content(self):
+ b = AssistantContentBuilder()
+ b.on_step_separator()
+ assert b.snapshot() == []
+
+ def test_separator_after_text_appends_with_step_index_zero(self):
+ b = AssistantContentBuilder()
+ b.on_text_start("text-1")
+ b.on_text_delta("text-1", "first")
+ b.on_text_end("text-1")
+
+ b.on_step_separator()
+
+ snap = b.snapshot()
+ assert snap[-1] == {
+ "type": "data-step-separator",
+ "data": {"stepIndex": 0},
+ }
+
+ def test_consecutive_separators_collapse_to_one(self):
+ b = AssistantContentBuilder()
+ b.on_text_delta("text-1", "x")
+ b.on_step_separator()
+ b.on_step_separator() # No-op: previous part is already a separator.
+ snap = b.snapshot()
+ assert sum(1 for p in snap if p["type"] == "data-step-separator") == 1
+
+ def test_step_index_increments_across_separators(self):
+ b = AssistantContentBuilder()
+ b.on_text_delta("text-1", "a")
+ b.on_step_separator()
+ b.on_text_delta("text-2", "b")
+ b.on_step_separator()
+ snap = b.snapshot()
+ seps = [p for p in snap if p["type"] == "data-step-separator"]
+ assert [s["data"]["stepIndex"] for s in seps] == [0, 1]
+
+
+# ---------------------------------------------------------------------------
+# Interruption handling
+# ---------------------------------------------------------------------------
+
+
+class TestMarkInterrupted:
+ def test_running_tool_calls_get_state_aborted(self):
+ b = AssistantContentBuilder()
+ b.on_tool_input_start("call_a", "ls", "lc_a")
+ b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
+ # No tool-output-available — simulates client disconnect mid-tool.
+
+ b.mark_interrupted()
+
+ snap = b.snapshot()
+ assert snap[0]["state"] == "aborted"
+ assert "result" not in snap[0]
+
+ def test_completed_tool_calls_are_not_marked_aborted(self):
+ b = AssistantContentBuilder()
+ b.on_tool_input_start("call_a", "ls", "lc_a")
+ b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
+ b.on_tool_output_available("call_a", {"files": []}, "lc_a")
+
+ b.mark_interrupted()
+
+ snap = b.snapshot()
+ assert "state" not in snap[0]
+ assert snap[0]["result"] == {"files": []}
+
+ def test_open_text_block_keeps_accumulated_content(self):
+ b = AssistantContentBuilder()
+ b.on_text_start("text-1")
+ b.on_text_delta("text-1", "partial")
+ # No on_text_end — disconnect mid-stream.
+
+ b.mark_interrupted()
+
+ snap = b.snapshot()
+ assert snap == [{"type": "text", "text": "partial"}]
+
+
+# ---------------------------------------------------------------------------
+# is_empty / snapshot semantics
+# ---------------------------------------------------------------------------
+
+
+class TestIsEmpty:
+ def test_fresh_builder_is_empty(self):
+ assert AssistantContentBuilder().is_empty()
+
+ def test_text_part_breaks_emptiness(self):
+ b = AssistantContentBuilder()
+ b.on_text_delta("text-1", "x")
+ assert not b.is_empty()
+
+ def test_tool_call_breaks_emptiness(self):
+ b = AssistantContentBuilder()
+ b.on_tool_input_start("call_x", "ls", None)
+ assert not b.is_empty()
+
+ def test_thinking_step_alone_does_not_break_emptiness(self):
+ # Mirrors the "status marker fallback" semantic: a turn that
+ # only emitted a thinking step before being interrupted should
+ # still be treated as empty for finalize_assistant_turn's
+ # status-marker substitution.
+ b = AssistantContentBuilder()
+ b.on_thinking_step("step-1", "Working", "in_progress", [])
+ assert b.is_empty()
+
+ def test_step_separator_alone_does_not_break_emptiness(self):
+ b = AssistantContentBuilder()
+ # Force a separator (it would normally no-op without content,
+ # but we simulate the underlying state to verify is_empty is
+ # not fooled by a stray separator).
+ b.parts.append({"type": "data-step-separator", "data": {"stepIndex": 0}})
+ assert b.is_empty()
+
+
+class TestSnapshotSemantics:
+ def test_snapshot_is_deep_copied_so_mutations_do_not_leak(self):
+ b = AssistantContentBuilder()
+ b.on_tool_input_start("call_x", "ls", "lc_x")
+ b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
+ snap = b.snapshot()
+ # Mutate the returned snapshot — original should be untouched.
+ snap[0]["args"]["mutated"] = True
+ snap[0]["state"] = "tampered"
+
+ again = b.snapshot()
+ assert "mutated" not in again[0]["args"]
+ assert "state" not in again[0]
+
+ def test_snapshot_round_trips_through_json(self):
+ b = AssistantContentBuilder()
+ b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item"])
+ b.on_text_delta("text-1", "answer")
+ b.on_tool_input_start("call_x", "ls", "lc_x")
+ b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
+ b.on_tool_output_available("call_x", {"files": ["a.txt"]}, "lc_x")
+ b.on_step_separator()
+ snap = b.snapshot()
+
+ encoded = json.dumps(snap)
+ assert json.loads(encoded) == snap
+
+
+class TestStats:
+ """``stats()`` is the perf-log handle for [PERF] [stream_*]
+ finalize_payload lines. Pin the schema so an ops dashboard can
+ rely on these keys being present and meaningful.
+ """
+
+ def test_fresh_builder_reports_all_zeros(self):
+ b = AssistantContentBuilder()
+ s = b.stats()
+ assert s == {
+ "parts": 0,
+ "bytes": 2, # ``[]`` is two bytes
+ "text": 0,
+ "reasoning": 0,
+ "tool_calls": 0,
+ "tool_calls_completed": 0,
+ "tool_calls_aborted": 0,
+ "thinking_step_parts": 0,
+ "step_separators": 0,
+ }
+
+ def test_counts_each_part_type_independently(self):
+ b = AssistantContentBuilder()
+ b.on_text_start("t1")
+ b.on_text_delta("t1", "hi")
+ b.on_text_end("t1")
+ b.on_reasoning_start("r1")
+ b.on_reasoning_delta("r1", "thinking")
+ b.on_reasoning_end("r1")
+ b.on_thinking_step("step-1", "Analyzing", "completed", ["item"])
+ b.on_step_separator()
+ b.on_tool_input_start("call_done", "ls", "lc_done")
+ b.on_tool_input_available("call_done", "ls", {}, "lc_done")
+ b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
+ b.on_tool_input_start("call_running", "rm", "lc_running")
+ b.on_tool_input_available("call_running", "rm", {}, "lc_running")
+
+ s = b.stats()
+ assert s["text"] == 1
+ assert s["reasoning"] == 1
+ assert s["tool_calls"] == 2
+ assert s["tool_calls_completed"] == 1
+ assert s["tool_calls_aborted"] == 0
+ assert s["thinking_step_parts"] == 1
+ assert s["step_separators"] == 1
+ assert s["parts"] == sum(
+ [
+ s["text"],
+ s["reasoning"],
+ s["tool_calls"],
+ s["thinking_step_parts"],
+ s["step_separators"],
+ ]
+ )
+ assert s["bytes"] > 0
+
+ def test_mark_interrupted_flips_running_calls_to_aborted_in_stats(self):
+ b = AssistantContentBuilder()
+ b.on_tool_input_start("call_done", "ls", "lc_done")
+ b.on_tool_input_available("call_done", "ls", {}, "lc_done")
+ b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
+ b.on_tool_input_start("call_running", "rm", "lc_running")
+ b.on_tool_input_available("call_running", "rm", {}, "lc_running")
+
+ # Pre-interrupt: one completed, one still running (no result).
+ pre = b.stats()
+ assert pre["tool_calls_completed"] == 1
+ assert pre["tool_calls_aborted"] == 0
+
+ b.mark_interrupted()
+ post = b.stats()
+ assert post["tool_calls_completed"] == 1
+ assert post["tool_calls_aborted"] == 1
+ assert post["tool_calls"] == 2
+
+ def test_bytes_reflects_jsonb_payload_size(self):
+ # Each text-delta adds bytes monotonically — useful for catching
+ # an unbounded delta buffer regression in the perf signal.
+ b = AssistantContentBuilder()
+ b.on_text_start("t1")
+ b.on_text_delta("t1", "x" * 10)
+ small = b.stats()["bytes"]
+ b.on_text_delta("t1", "x" * 1000)
+ large = b.stats()["bytes"]
+ assert large > small + 900
diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
index 9258d5cfe..60750396c 100644
--- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
+++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
@@ -51,22 +51,34 @@ class _FakeToolMessage:
tool_call_id: str | None = None
+@dataclass
+class _FakeInterrupt:
+ value: dict[str, Any]
+
+
+@dataclass
+class _FakeTask:
+ interrupts: tuple[_FakeInterrupt, ...] = ()
+
+
class _FakeAgentState:
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
- def __init__(self) -> None:
+ def __init__(self, tasks: list[Any] | None = None) -> None:
# Empty values keeps the cloud-fallback safety-net branch a no-op,
- # and an empty ``tasks`` list keeps the post-stream interrupt
- # check a no-op too.
+ # and an empty ``tasks`` list keeps the post-stream interrupt check a no-op too.
self.values: dict[str, Any] = {}
- self.tasks: list[Any] = []
+ self.tasks: list[Any] = tasks or []
class _FakeAgent:
"""Replays a list of ``astream_events`` events."""
- def __init__(self, events: list[dict[str, Any]]) -> None:
+ def __init__(
+ self, events: list[dict[str, Any]], state: _FakeAgentState | None = None
+ ) -> None:
self._events = events
+ self._state = state or _FakeAgentState()
async def astream_events( # type: ignore[no-untyped-def]
self, _input_data: Any, *, config: dict[str, Any], version: str
@@ -79,7 +91,7 @@ class _FakeAgent:
# Called once after astream_events drains so the cloud-fallback
# safety net can inspect staged filesystem work. The fake stays
# empty so the safety net is a no-op.
- return _FakeAgentState()
+ return self._state
def _model_stream(
@@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
)
-async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
+async def _drain(
+ events: list[dict[str, Any]], state: _FakeAgentState | None = None
+) -> list[dict[str, Any]]:
"""Run ``_stream_agent_events`` against a fake agent and return the
SSE payloads (parsed JSON) it yielded.
"""
- agent = _FakeAgent(events)
+ agent = _FakeAgent(events, state=state)
service = VercelStreamingService()
result = StreamResult()
config = {"configurable": {"thread_id": "test-thread"}}
@@ -525,3 +539,31 @@ async def test_unmatched_fallback_still_attaches_lc_id(
assert len(starts) == 1
assert starts[0]["toolCallId"].startswith("call_run-1")
assert starts[0]["langchainToolCallId"] == "lc-orphan"
+
+
+@pytest.mark.asyncio
+async def test_interrupt_request_uses_task_that_contains_interrupt(
+ parity_v2_on: None,
+) -> None:
+ interrupt_payload = {
+ "type": "calendar_event_create",
+ "action": {
+ "tool": "create_calendar_event",
+ "params": {"summary": "mom bday"},
+ },
+ "context": {},
+ }
+ state = _FakeAgentState(
+ tasks=[
+ _FakeTask(interrupts=()),
+ _FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)),
+ ]
+ )
+
+ payloads = await _drain([], state=state)
+
+ interrupts = _of_type(payloads, "data-interrupt-request")
+ assert len(interrupts) == 1
+ assert (
+ interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event"
+ )
diff --git a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py
new file mode 100644
index 000000000..a5bb3f58a
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py
@@ -0,0 +1,318 @@
+"""Regression tests for ``run_async_celery_task``.
+
+These tests pin down the production bug observed on 2026-05-02 where
+the video-presentation Celery task hung at ``[billable_call] finalize``
+because the shared ``app.db.engine`` had pooled asyncpg connections
+bound to a *previous* task's now-closed event loop. Reusing such a
+connection on a fresh loop crashes inside ``pool_pre_ping`` with::
+
+ AttributeError: 'NoneType' object has no attribute 'send'
+
+(the proactor is None because the loop is gone) and can hang forever
+inside the asyncpg ``Connection._cancel`` cleanup coroutine.
+
+The fix is ``run_async_celery_task``: a small helper that runs every
+async celery task body inside a fresh event loop and disposes the
+shared engine pool both before (defends against a previous task that
+crashed) and after (releases connections we opened on this loop).
+
+Tests here exercise the helper with a stub engine that records each
+``dispose()`` call together with the id of the loop it ran on, so the
+tests can assert disposal happened on every fresh loop.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import gc
+import sys
+from collections.abc import Iterator
+from contextlib import contextmanager
+from unittest.mock import patch
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Stub engine that emulates the asyncpg-on-stale-loop crash
+# ---------------------------------------------------------------------------
+
+
+class _StaleLoopEngine:
+ """Tiny stand-in for ``app.db.engine`` that tracks dispose() calls.
+
+ ``dispose()`` is async (matches ``AsyncEngine.dispose``) and records
+ the running event loop id so tests can assert it ran on *each*
+ fresh loop.
+ """
+
+ def __init__(self) -> None:
+ self.dispose_loop_ids: list[int] = []
+
+ async def dispose(self) -> None:
+ loop = asyncio.get_running_loop()
+ self.dispose_loop_ids.append(id(loop))
+
+
+@contextmanager
+def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]:
+ """Patch ``from app.db import engine as shared_engine`` lookup.
+
+ The helper imports lazily inside the function body, so we have to
+ patch the attribute on the already-loaded ``app.db`` module.
+ """
+ import app.db as app_db
+
+ original = getattr(app_db, "engine", None)
+ app_db.engine = stub # type: ignore[attr-defined]
+ try:
+ yield
+    finally:
+        if original is None:
+            # Restore the "attribute absent" state; the original
+            # ``pytest.raises(AttributeError)`` here could never pass
+            # while the stub was still assigned.
+            delattr(app_db, "engine")
+        else:
+            app_db.engine = original  # type: ignore[attr-defined]
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+def test_runner_returns_value_and_disposes_engine_around_call() -> None:
+ """Happy path: the coroutine result is returned, and the shared
+ engine is disposed both before and after the task body runs.
+ """
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ stub = _StaleLoopEngine()
+
+ async def _body() -> str:
+ # Engine should already have been disposed once before we run.
+ assert len(stub.dispose_loop_ids) == 1
+ return "ok"
+
+ with _patch_shared_engine(stub):
+ result = run_async_celery_task(_body)
+
+ assert result == "ok"
+ # Once before the body, once after (in finally).
+ assert len(stub.dispose_loop_ids) == 2
+ # Both disposes ran on the SAME (fresh) loop the task body used.
+ assert stub.dispose_loop_ids[0] == stub.dispose_loop_ids[1]
+
+
+def test_runner_creates_fresh_loop_per_invocation() -> None:
+ """Each call must spin its own loop. Without this guarantee a
+ previous task's loop would be reused and the asyncpg-stale-loop
+ crash would never be avoided.
+ """
+ import app.tasks.celery_tasks as celery_tasks_pkg
+
+ stub = _StaleLoopEngine()
+ new_loop_calls = 0
+ closed_loops: list[bool] = []
+
+ real_new_event_loop = asyncio.new_event_loop
+
+ def _counting_new_loop() -> asyncio.AbstractEventLoop:
+ nonlocal new_loop_calls
+ new_loop_calls += 1
+ loop = real_new_event_loop()
+ # Hook close() so we can verify each loop was closed properly
+ # before the next one was created.
+ original_close = loop.close
+
+ def _tracked_close() -> None:
+ closed_loops.append(True)
+ original_close()
+
+ loop.close = _tracked_close # type: ignore[method-assign]
+ return loop
+
+ async def _body() -> None:
+ # Loop is alive and current at body execution time.
+ running = asyncio.get_running_loop()
+ assert not running.is_closed()
+
+ with (
+ _patch_shared_engine(stub),
+ patch.object(asyncio, "new_event_loop", _counting_new_loop),
+ ):
+ for _ in range(3):
+ celery_tasks_pkg.run_async_celery_task(_body)
+
+ assert new_loop_calls == 3
+ assert closed_loops == [True, True, True]
+ # Each invocation disposed twice (before + after).
+ assert len(stub.dispose_loop_ids) == 6
+
+
+def test_runner_disposes_engine_even_when_body_raises() -> None:
+ """Cleanup MUST run on the failure path too — otherwise stale
+ connections leak into the next task and cause the original hang.
+ """
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ stub = _StaleLoopEngine()
+
+ class _BoomError(RuntimeError):
+ pass
+
+ async def _body() -> None:
+ raise _BoomError("kaboom")
+
+ with _patch_shared_engine(stub), pytest.raises(_BoomError):
+ run_async_celery_task(_body)
+
+ assert len(stub.dispose_loop_ids) == 2 # before + after still ran
+
+
+def test_runner_swallows_dispose_errors() -> None:
+ """A flaky engine.dispose() must NEVER take down a celery task.
+
+ Production scenario: the very first dispose (before the body runs)
+ might hit a partially-initialised engine; the helper logs and
+ moves on. The task body still runs; the result is still returned.
+ """
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ class _AngryEngine:
+ def __init__(self) -> None:
+ self.calls = 0
+
+ async def dispose(self) -> None:
+ self.calls += 1
+ raise RuntimeError("dispose() blew up")
+
+ stub = _AngryEngine()
+
+ async def _body() -> int:
+ return 42
+
+ with _patch_shared_engine(stub):
+ assert run_async_celery_task(_body) == 42
+
+ assert stub.calls == 2 # before + after both attempted
+
+
+def test_runner_propagates_value_from_async_body() -> None:
+ """Sanity: pass-through of any pickleable celery return value."""
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ stub = _StaleLoopEngine()
+
+ async def _body() -> dict[str, object]:
+ return {"status": "ready", "video_presentation_id": 19}
+
+ with _patch_shared_engine(stub):
+ out = run_async_celery_task(_body)
+
+ assert out == {"status": "ready", "video_presentation_id": 19}
+
+
+def test_video_presentation_task_uses_runner_helper() -> None:
+ """Defence-in-depth: confirm the celery task module imports
+ ``run_async_celery_task``. If a future refactor inlines a
+ ``loop = asyncio.new_event_loop(); ... loop.close()`` block again,
+ the original hang will return.
+ """
+ # The module's task body should not contain a manual new_event_loop
+ # call — that's exactly what the helper exists to centralise.
+ import inspect
+
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ src = inspect.getsource(video_presentation_tasks)
+ assert "run_async_celery_task" in src, (
+ "video_presentation_tasks.py must use run_async_celery_task; "
+ "manual asyncio.new_event_loop() in a celery task hangs on the "
+ "shared SQLAlchemy pool when reused across tasks."
+ )
+ assert "asyncio.new_event_loop" not in src, (
+ "video_presentation_tasks.py contains a raw asyncio.new_event_loop "
+ "call — route every async task through run_async_celery_task to "
+ "avoid the stale-pool hang."
+ )
+
+
+def test_podcast_task_uses_runner_helper() -> None:
+ """Symmetric assertion for the podcast task — same root cause, same
+ fix, same regression risk.
+ """
+ import inspect
+
+ from app.tasks.celery_tasks import podcast_tasks
+
+ src = inspect.getsource(podcast_tasks)
+ assert "run_async_celery_task" in src
+ assert "asyncio.new_event_loop" not in src
+
+
+def test_runner_runs_shutdown_asyncgens_before_close() -> None:
+ """If the task body created any async generators that didn't get
+ fully iterated, we must still call ``loop.shutdown_asyncgens()``
+ before closing — otherwise we leak event-loop bound resources
+ that re-emerge as ``RuntimeError: Event loop is closed`` later.
+ """
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ stub = _StaleLoopEngine()
+
+ async def _agen():
+ try:
+ yield 1
+ yield 2
+ finally:
+ pass
+
+ async def _body() -> None:
+ # Iterate the agen partially, then leave it dangling — exactly
+ # the situation shutdown_asyncgens() is designed to clean up.
+ async for v in _agen():
+ if v == 1:
+ break
+
+ with _patch_shared_engine(stub):
+ run_async_celery_task(_body)
+
+ # By the time the helper returns, garbage collection + shutdown_asyncgens
+ # should have ensured no live async-gen references remain. We don't
+ # assert agen.closed directly (it depends on GC ordering); the real
+ # contract is "no warnings, no event-loop-closed errors". A successful
+ # second invocation proves the loop was cleaned up properly.
+ with _patch_shared_engine(stub):
+ run_async_celery_task(_body)
+
+ # Force a GC pass to surface any 'coroutine was never awaited'
+ # warnings that would indicate the cleanup is broken.
+ gc.collect()
+
+
+def test_runner_uses_proactor_loop_on_windows() -> None:
+ """On Windows the celery worker preselects a Proactor policy so
+ subprocess (ffmpeg) calls work. The helper must not silently fall
+ back to a Selector loop and re-break video/podcast generation.
+ """
+ if not sys.platform.startswith("win"):
+ pytest.skip("Windows-specific event-loop policy assertion")
+
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ stub = _StaleLoopEngine()
+
+    # Mirror the policy set at the top of every Windows celery task, and
+    # restore the previous policy afterwards so the test doesn't leak it.
+    previous_policy = asyncio.get_event_loop_policy()
+    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
+
+    observed: list[str] = []
+
+    async def _body() -> None:
+        observed.append(type(asyncio.get_running_loop()).__name__)
+
+    try:
+        with _patch_shared_engine(stub):
+            run_async_celery_task(_body)
+    finally:
+        asyncio.set_event_loop_policy(previous_policy)
+
+    assert observed == ["ProactorEventLoop"]
diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py
new file mode 100644
index 000000000..699297df1
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py
@@ -0,0 +1,388 @@
+"""Unit tests for podcast Celery task billing integration.
+
+Validates that ``_generate_content_podcast`` correctly wraps
+``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the
+search-space owner's billing decision, and degrades cleanly when the
+resolver fails, premium credit is exhausted, or billing settlement fails.
+
+Coverage:
+
+* Happy-path free config: resolver → ``billable_call`` enters with
+ ``usage_type='podcast_generation'`` and the configured reserve override,
+ graph runs, podcast row flips to ``READY``.
+* Happy-path premium config: same wiring with ``billing_tier='premium'``.
+* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` →
+ graph is *not* invoked, podcast row flips to ``FAILED``, return dict
+ carries ``reason='premium_quota_exhausted'``.
+* Settlement failure: ``BillingSettlementError`` on envelope exit → podcast
+  row flips to ``FAILED``, return dict carries
+  ``reason='billing_settlement_failed'``.
+* Resolver failure: ``ValueError`` from the resolver → podcast row flips
+  to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``.
+"""
+
+from __future__ import annotations
+
+import contextlib
+from types import SimpleNamespace
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeExecResult:
+ def __init__(self, obj):
+ self._obj = obj
+
+ def scalars(self):
+ return self
+
+ def first(self):
+ return self._obj
+
+ def filter(self, *_args, **_kwargs):
+ return self
+
+
+class _FakeSession:
+ def __init__(self, podcast):
+ self._podcast = podcast
+ self.commit_count = 0
+
+ async def execute(self, _stmt):
+ return _FakeExecResult(self._podcast)
+
+ async def commit(self):
+ self.commit_count += 1
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, *args):
+ return None
+
+
+class _FakeSessionMaker:
+ def __init__(self, session: _FakeSession):
+ self._session = session
+
+ def __call__(self):
+ return self._session
+
+
+def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace:
+ """Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily
+ inside helpers keeps this fixture cheap."""
+ return SimpleNamespace(
+ id=podcast_id,
+ title="Test Podcast",
+ thread_id=thread_id,
+ status=None,
+ podcast_transcript=None,
+ file_location=None,
+ )
+
+
+_CALL_LOG: list[dict[str, Any]] = []
+
+
+@contextlib.asynccontextmanager
+async def _ok_billable_call(**kwargs):
+    """Stand-in for ``billable_call`` that records its kwargs and yields a
+    no-op accumulator-shaped object."""
+    _CALL_LOG.append(kwargs)
+    yield SimpleNamespace()
+
+
+@contextlib.asynccontextmanager
+async def _denying_billable_call(**kwargs):
+ from app.services.billable_calls import QuotaInsufficientError
+
+ _CALL_LOG.append(kwargs)
+ raise QuotaInsufficientError(
+ usage_type=kwargs.get("usage_type", "?"),
+ used_micros=5_000_000,
+ limit_micros=5_000_000,
+ remaining_micros=0,
+ )
+    yield SimpleNamespace()  # pragma: no cover (unreachable; makes this an async generator)
+
+
+@contextlib.asynccontextmanager
+async def _settlement_failing_billable_call(**kwargs):
+ from app.services.billable_calls import BillingSettlementError
+
+ _CALL_LOG.append(kwargs)
+ yield SimpleNamespace()
+ raise BillingSettlementError(
+ usage_type=kwargs.get("usage_type", "?"),
+ user_id=kwargs["user_id"],
+ cause=RuntimeError("finalize failed"),
+ )
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def _reset_call_log():
+ _CALL_LOG.clear()
+ yield
+ _CALL_LOG.clear()
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
+ """Happy path: free billing tier still wraps the graph call so the
+ audit row is recorded. Verifies kwargs threading."""
+ from app.config import config as app_config
+ from app.db import PodcastStatus
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast(podcast_id=7, thread_id=99)
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ assert search_space_id == 555
+ assert thread_id == 99
+ return user_id, "free", "openrouter/some-free-model"
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
+ )
+ monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {
+ "podcast_transcript": [
+ SimpleNamespace(speaker_id=0, dialog="Hi"),
+ SimpleNamespace(speaker_id=1, dialog="Hello"),
+ ],
+ "final_podcast_file_path": "/tmp/podcast.wav",
+ }
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ result = await podcast_tasks._generate_content_podcast(
+ podcast_id=7,
+ source_content="hello world",
+ search_space_id=555,
+ user_prompt="make it short",
+ )
+
+ assert result["status"] == "ready"
+ assert result["podcast_id"] == 7
+ assert podcast.status == PodcastStatus.READY
+ assert podcast.file_location == "/tmp/podcast.wav"
+
+ assert len(_CALL_LOG) == 1
+ call = _CALL_LOG[0]
+ assert call["user_id"] == user_id
+ assert call["search_space_id"] == 555
+ assert call["billing_tier"] == "free"
+ assert call["base_model"] == "openrouter/some-free-model"
+ assert call["usage_type"] == "podcast_generation"
+ assert (
+ call["quota_reserve_micros_override"]
+ == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
+ )
+ # Background artifact audit rows intentionally omit the TokenUsage.thread_id
+ # FK to avoid coupling Celery audit commits to an active chat transaction.
+ assert "thread_id" not in call
+ assert call["call_details"] == {
+ "podcast_id": 7,
+ "title": "Test Podcast",
+ "thread_id": 99,
+ }
+ assert callable(call["billable_session_factory"])
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_premium_tier(monkeypatch):
+ """Premium resolution flows through to ``billable_call`` so the
+ reserve/finalize path triggers."""
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast()
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return user_id, "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
+ )
+ monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ await podcast_tasks._generate_content_podcast(
+ podcast_id=7,
+ source_content="hi",
+ search_space_id=555,
+ user_prompt=None,
+ )
+
+ assert _CALL_LOG[0]["billing_tier"] == "premium"
+ assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch):
+ """When ``billable_call`` denies the reservation, the graph never
+ runs and the podcast row flips to FAILED with the documented reason
+ code."""
+ from app.db import PodcastStatus
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast(podcast_id=8)
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return uuid4(), "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
+ )
+ monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call)
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ result = await podcast_tasks._generate_content_podcast(
+ podcast_id=8,
+ source_content="hi",
+ search_space_id=555,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "podcast_id": 8,
+ "reason": "premium_quota_exhausted",
+ }
+ assert podcast.status == PodcastStatus.FAILED
+ assert graph_invoked == [] # Graph never ran on denied reservation.
+
+
+@pytest.mark.asyncio
+async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch):
+ from app.db import PodcastStatus
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast(podcast_id=10)
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return uuid4(), "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
+ )
+ monkeypatch.setattr(
+ podcast_tasks, "billable_call", _settlement_failing_billable_call
+ )
+
+ async def _fake_graph_invoke(state, config):
+ return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ result = await podcast_tasks._generate_content_podcast(
+ podcast_id=10,
+ source_content="hi",
+ search_space_id=555,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "podcast_id": 10,
+ "reason": "billing_settlement_failed",
+ }
+ assert podcast.status == PodcastStatus.FAILED
+
+
+@pytest.mark.asyncio
+async def test_resolver_failure_marks_podcast_failed(monkeypatch):
+ """If the resolver raises (e.g. search-space deleted), the task fails
+ cleanly without invoking the graph."""
+ from app.db import PodcastStatus
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast(podcast_id=9)
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _failing_resolver(sess, search_space_id, *, thread_id=None):
+ raise ValueError("Search space 555 not found")
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver
+ )
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ result = await podcast_tasks._generate_content_podcast(
+ podcast_id=9,
+ source_content="hi",
+ search_space_id=555,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "podcast_id": 9,
+ "reason": "billing_resolution_failed",
+ }
+ assert podcast.status == PodcastStatus.FAILED
+ assert graph_invoked == []
diff --git a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py
new file mode 100644
index 000000000..792d059b0
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py
@@ -0,0 +1,119 @@
+"""Predicate-level test for the chat streaming safety net.
+
+The safety net in ``stream_new_chat`` rejects an image turn early with
+a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the
+selected model is *known* to be text-only. The earlier round of this
+work used a strict opt-in flag (``supports_image_input`` defaulting to
+False on every YAML entry) which blocked vision-capable Azure GPT-5.x
+deployments — this is the regression we're fixing.
+
+The new predicate is :func:`is_known_text_only_chat_model`, which
+returns True only when LiteLLM's authoritative model map *explicitly*
+sets ``supports_vision=False``. Anything else (vision True, missing
+key, exception) returns False so the request flows through to the
+provider.
+
+We exercise the predicate directly here rather than driving the full
+``stream_new_chat`` generator — covering the gate in isolation keeps
+the test focused on the regression while the generator's wider behavior
+is exercised by the integration suite.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from app.services.provider_capabilities import is_known_text_only_chat_model
+
+pytestmark = pytest.mark.unit
+
+
+def test_safety_net_does_not_fire_for_azure_gpt_4o():
+ """Regression: ``azure/gpt-4o`` (and the GPT-5.x variants) is
+ vision-capable per LiteLLM's model map. The previous round's
+ blanket-False default blocked it; the new predicate must NOT mark
+ it text-only."""
+ assert (
+ is_known_text_only_chat_model(
+ provider="AZURE_OPENAI",
+ model_name="my-azure-deployment",
+ base_model="gpt-4o",
+ )
+ is False
+ )
+
+
+def test_safety_net_does_not_fire_for_unknown_model():
+ """Default-pass on unknown — the safety net only blocks definitive
+ text-only confirmations. A freshly added third-party model that
+ LiteLLM doesn't know about must flow through to the provider."""
+ assert (
+ is_known_text_only_chat_model(
+ provider="CUSTOM",
+ custom_provider="brand_new_proxy",
+ model_name="brand-new-model-x9",
+ )
+ is False
+ )
+
+
+def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
+ """Transient ``litellm.get_model_info`` exception ≠ block. The
+ helper swallows the error and treats it as 'unknown' → False."""
+ import app.services.provider_capabilities as pc
+
+ def _raise(**_kwargs):
+ raise RuntimeError("intentional test failure")
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _raise)
+
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="gpt-4o",
+ )
+ is False
+ )
+
+
+def test_safety_net_fires_only_on_explicit_false(monkeypatch):
+ """Stub LiteLLM to assert the only path that returns True is the
+ explicit ``supports_vision=False`` case. Anything else (True,
+ None, missing key) returns False from the predicate."""
+ import app.services.provider_capabilities as pc
+
+ def _info_explicit_false(**_kwargs):
+ return {"supports_vision": False, "max_input_tokens": 8192}
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="text-only-stub",
+ )
+ is True
+ )
+
+ def _info_true(**_kwargs):
+ return {"supports_vision": True}
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="vision-stub",
+ )
+ is False
+ )
+
+ def _info_missing(**_kwargs):
+ return {"max_input_tokens": 8192}
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="missing-key-stub",
+ )
+ is False
+ )
diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py
new file mode 100644
index 000000000..423b64ddb
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py
@@ -0,0 +1,398 @@
+"""Unit tests for video-presentation Celery task billing integration.
+
+Mirrors ``test_podcast_billing.py`` for the video-presentation task.
+Validates the same wrap-graph-in-billable_call pattern and ensures the
+larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is
+threaded through.
+
+Coverage:
+
+* Free config: graph runs, ``billable_call`` invoked with the video
+ reserve override.
+* Premium config: same wiring with ``billing_tier='premium'``.
+* Quota denial: graph not invoked, row → FAILED, reason code surfaced.
+* Settlement failure: row → FAILED with ``billing_settlement_failed``.
+* Resolver failure: row → FAILED with ``billing_resolution_failed``.
+"""
+
+from __future__ import annotations
+
+import contextlib
+from types import SimpleNamespace
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeExecResult:
+ def __init__(self, obj):
+ self._obj = obj
+
+ def scalars(self):
+ return self
+
+ def first(self):
+ return self._obj
+
+ def filter(self, *_args, **_kwargs):
+ return self
+
+
+class _FakeSession:
+ def __init__(self, video):
+ self._video = video
+ self.commit_count = 0
+
+ async def execute(self, _stmt):
+ return _FakeExecResult(self._video)
+
+ async def commit(self):
+ self.commit_count += 1
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, *args):
+ return None
+
+
+class _FakeSessionMaker:
+ def __init__(self, session: _FakeSession):
+ self._session = session
+
+ def __call__(self):
+ return self._session
+
+
+def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace:
+ return SimpleNamespace(
+ id=video_id,
+ title="Test Presentation",
+ thread_id=thread_id,
+ status=None,
+ slides=None,
+ scene_codes=None,
+ )
+
+
+_CALL_LOG: list[dict[str, Any]] = []
+
+
+@contextlib.asynccontextmanager
+async def _ok_billable_call(**kwargs):
+ _CALL_LOG.append(kwargs)
+ yield SimpleNamespace()
+
+
+@contextlib.asynccontextmanager
+async def _denying_billable_call(**kwargs):
+ from app.services.billable_calls import QuotaInsufficientError
+
+ _CALL_LOG.append(kwargs)
+ raise QuotaInsufficientError(
+ usage_type=kwargs.get("usage_type", "?"),
+ used_micros=5_000_000,
+ limit_micros=5_000_000,
+ remaining_micros=0,
+ )
+ yield SimpleNamespace() # pragma: no cover
+
+
+@contextlib.asynccontextmanager
+async def _settlement_failing_billable_call(**kwargs):
+ from app.services.billable_calls import BillingSettlementError
+
+ _CALL_LOG.append(kwargs)
+ yield SimpleNamespace()
+ raise BillingSettlementError(
+ usage_type=kwargs.get("usage_type", "?"),
+ user_id=kwargs["user_id"],
+ cause=RuntimeError("finalize failed"),
+ )
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def _reset_call_log():
+ _CALL_LOG.clear()
+ yield
+ _CALL_LOG.clear()
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
+ from app.config import config as app_config
+ from app.db import VideoPresentationStatus
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video(video_id=11, thread_id=99)
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ assert search_space_id == 777
+ assert thread_id == 99
+ return user_id, "free", "openrouter/some-free-model"
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _fake_resolver,
+ )
+ monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ result = await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=11,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert result["status"] == "ready"
+ assert result["video_presentation_id"] == 11
+ assert video.status == VideoPresentationStatus.READY
+
+ assert len(_CALL_LOG) == 1
+ call = _CALL_LOG[0]
+ assert call["user_id"] == user_id
+ assert call["search_space_id"] == 777
+ assert call["billing_tier"] == "free"
+ assert call["base_model"] == "openrouter/some-free-model"
+ assert call["usage_type"] == "video_presentation_generation"
+ assert (
+ call["quota_reserve_micros_override"]
+ == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
+ )
+ # Background artifact audit rows intentionally omit the TokenUsage.thread_id
+ # FK to avoid coupling Celery audit commits to an active chat transaction.
+ assert "thread_id" not in call
+ assert call["call_details"] == {
+ "video_presentation_id": 11,
+ "title": "Test Presentation",
+ "thread_id": 99,
+ }
+ assert callable(call["billable_session_factory"])
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_premium_tier(monkeypatch):
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video()
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return user_id, "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _fake_resolver,
+ )
+ monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=11,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert _CALL_LOG[0]["billing_tier"] == "premium"
+ assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch):
+ from app.db import VideoPresentationStatus
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video(video_id=12)
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return uuid4(), "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _fake_resolver,
+ )
+ monkeypatch.setattr(
+ video_presentation_tasks, "billable_call", _denying_billable_call
+ )
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ result = await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=12,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "video_presentation_id": 12,
+ "reason": "premium_quota_exhausted",
+ }
+ assert video.status == VideoPresentationStatus.FAILED
+ assert graph_invoked == []
+
+
+@pytest.mark.asyncio
+async def test_billing_settlement_failure_marks_video_failed(monkeypatch):
+ from app.db import VideoPresentationStatus
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video(video_id=14)
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return uuid4(), "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _fake_resolver,
+ )
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "billable_call",
+ _settlement_failing_billable_call,
+ )
+
+ async def _fake_graph_invoke(state, config):
+ return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ result = await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=14,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "video_presentation_id": 14,
+ "reason": "billing_settlement_failed",
+ }
+ assert video.status == VideoPresentationStatus.FAILED
+
+
+@pytest.mark.asyncio
+async def test_resolver_failure_marks_video_failed(monkeypatch):
+ from app.db import VideoPresentationStatus
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video(video_id=13)
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _failing_resolver(sess, search_space_id, *, thread_id=None):
+ raise ValueError("Search space 777 not found")
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _failing_resolver,
+ )
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ result = await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=13,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "video_presentation_id": 13,
+ "reason": "billing_resolution_failed",
+ }
+ assert video.status == VideoPresentationStatus.FAILED
+ assert graph_invoked == []
diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py
index 034aa484c..208204ca9 100644
--- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py
+++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py
@@ -1,9 +1,21 @@
+import inspect
+import json
+import logging
+import re
+from pathlib import Path
+
import pytest
+import app.tasks.chat.stream_new_chat as stream_new_chat_module
+from app.agents.new_chat.errors import BusyError
+from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
from app.tasks.chat.stream_new_chat import (
StreamResult,
+ _classify_stream_exception,
_contract_enforcement_active,
_evaluate_file_contract_outcome,
+ _extract_resolved_file_path,
+ _log_chat_stream_error,
_tool_output_has_error,
)
@@ -17,6 +29,39 @@ def test_tool_output_error_detection():
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
+def test_extract_resolved_file_path_prefers_structured_path():
+ assert (
+ _extract_resolved_file_path(
+ tool_name="write_file",
+ tool_output={"status": "completed", "path": "/docs/note.md"},
+ tool_input=None,
+ )
+ == "/docs/note.md"
+ )
+
+
+def test_extract_resolved_file_path_falls_back_to_tool_input():
+ assert (
+ _extract_resolved_file_path(
+ tool_name="edit_file",
+ tool_output={"status": "completed", "result": "updated"},
+ tool_input={"file_path": "/docs/edited.md"},
+ )
+ == "/docs/edited.md"
+ )
+
+
+def test_extract_resolved_file_path_does_not_parse_result_text():
+ assert (
+ _extract_resolved_file_path(
+ tool_name="write_file",
+ tool_output={"result": "Updated file /docs/from-text.md"},
+ tool_input=None,
+ )
+ is None
+ )
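+
+# Assumed helper shape behind the three cases above (a sketch only; the
+# real code lives in app.tasks.chat.stream_new_chat):
+#
+#     def _extract_resolved_file_path(*, tool_name, tool_output, tool_input):
+#         if isinstance(tool_output, dict) and tool_output.get("path"):
+#             return str(tool_output["path"])  # structured path wins
+#         if tool_input and tool_input.get("file_path"):
+#             return str(tool_input["file_path"])  # declared-intent fallback
+#         return None  # never parse paths out of human-readable result text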
+
+
def test_file_write_contract_outcome_reasons():
result = StreamResult(intent_detected="file_write")
passed, reason = _evaluate_file_contract_outcome(result)
@@ -45,3 +90,507 @@ def test_contract_enforcement_local_only():
result.filesystem_mode = "cloud"
assert not _contract_enforcement_active(result)
+
+
+def _extract_chat_stream_payload(record_message: str) -> dict:
+ prefix = "[chat_stream_error] "
+ assert record_message.startswith(prefix)
+ return json.loads(record_message[len(prefix) :])
+
+
+def test_unified_chat_stream_error_log_schema(caplog):
+ with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
+ _log_chat_stream_error(
+ flow="new",
+ error_kind="server_error",
+ error_code="SERVER_ERROR",
+ severity="warn",
+ is_expected=False,
+ request_id="req-123",
+ thread_id=101,
+ search_space_id=202,
+ user_id="user-1",
+ message="Error during chat: boom",
+ )
+
+ record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
+ payload = _extract_chat_stream_payload(record.message)
+
+ required_keys = {
+ "event",
+ "flow",
+ "error_kind",
+ "error_code",
+ "severity",
+ "is_expected",
+ "request_id",
+ "thread_id",
+ "search_space_id",
+ "user_id",
+ "message",
+ }
+ assert required_keys.issubset(payload.keys())
+ assert payload["event"] == "chat_stream_error"
+ assert payload["flow"] == "new"
+ assert payload["error_code"] == "SERVER_ERROR"
+
+
+def test_premium_quota_uses_unified_chat_stream_log_shape(caplog):
+ with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
+ _log_chat_stream_error(
+ flow="resume",
+ error_kind="premium_quota_exhausted",
+ error_code="PREMIUM_QUOTA_EXHAUSTED",
+ severity="info",
+ is_expected=True,
+ request_id="req-premium",
+ thread_id=303,
+ search_space_id=404,
+ user_id="user-2",
+ message="Buy more tokens to continue with this model, or switch to a free model",
+ extra={"auto_fallback": False},
+ )
+
+ record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
+ payload = _extract_chat_stream_payload(record.message)
+ assert payload["event"] == "chat_stream_error"
+ assert payload["error_kind"] == "premium_quota_exhausted"
+ assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED"
+ assert payload["flow"] == "resume"
+ assert payload["is_expected"] is True
+ assert payload["auto_fallback"] is False
+
+
+def test_stream_error_emission_keeps_machine_error_codes():
+ source = inspect.getsource(stream_new_chat_module)
+ format_error_calls = re.findall(r"format_error\(", source)
+ emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source))
+
+ # All stream paths should route through one shared terminal error emitter.
+ assert len(format_error_calls) == 1
+ assert {
+ "PREMIUM_QUOTA_EXHAUSTED",
+ "SERVER_ERROR",
+ }.issubset(emitted_error_codes)
+ assert 'flow: Literal["new", "regenerate"] = "new"' in source
+ assert "_emit_stream_terminal_error" in source
+ assert "flow=flow" in source
+ assert 'flow="resume"' in source
+
+
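+# ``_classify_stream_exception`` returns a 6-tuple, unpacked throughout the
+# tests below as:
+#   (error_kind, error_code, severity, is_expected, user_message, extra)
+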
+def test_stream_exception_classifies_rate_limited():
+ exc = Exception(
+ '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
+ )
+ kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
+ exc, flow_label="chat"
+ )
+ assert kind == "rate_limited"
+ assert code == "RATE_LIMITED"
+ assert severity == "warn"
+ assert is_expected is True
+ assert "temporarily rate-limited" in user_message
+ assert extra is None
+
+
+def test_stream_exception_classifies_openrouter_429_payload():
+ exc = Exception(
+ 'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
+ '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
+ )
+ kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
+ exc, flow_label="chat"
+ )
+ assert kind == "rate_limited"
+ assert code == "RATE_LIMITED"
+ assert severity == "warn"
+ assert is_expected is True
+ assert "temporarily rate-limited" in user_message
+ assert extra is None
+
+
+@pytest.mark.asyncio
+async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch):
+ """``_preflight_llm`` is best-effort.
+
+ - On rate-limit shaped exceptions (provider 429) it MUST re-raise so the
+ caller can drive the cooldown/repin branch.
+ - On any other transient failure it MUST swallow the error so the normal
+ stream path continues without surfacing preflight noise to the user.
+ """
+ from types import SimpleNamespace
+
+ from app.tasks.chat.stream_new_chat import _preflight_llm
+
+ class _RateLimitedError(Exception):
+ """Class-name carries 'RateLimit' so _is_provider_rate_limited triggers."""
+
+ rate_calls: list[dict] = []
+ other_calls: list[dict] = []
+
+ async def _fake_acompletion_429(**kwargs):
+ rate_calls.append(kwargs)
+ raise _RateLimitedError("simulated 429")
+
+ async def _fake_acompletion_other(**kwargs):
+ other_calls.append(kwargs)
+ raise RuntimeError("some unrelated transient failure")
+
+ fake_llm = SimpleNamespace(
+ model="openrouter/google/gemma-4-31b-it:free",
+ api_key="test",
+ api_base=None,
+ )
+
+ import litellm # type: ignore[import-not-found]
+
+ monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429)
+ with pytest.raises(_RateLimitedError):
+ await _preflight_llm(fake_llm)
+ assert len(rate_calls) == 1
+ assert rate_calls[0]["max_tokens"] == 1
+ assert rate_calls[0]["stream"] is False
+
+ monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other)
+ # MUST NOT raise: non-rate-limit failures are swallowed.
+ await _preflight_llm(fake_llm)
+ assert len(other_calls) == 1
+
+
+@pytest.mark.asyncio
+async def test_preflight_skipped_for_auto_router_model():
+ """Router-mode ``model='auto'`` has no single deployment to ping; the
+ LiteLLM router itself owns per-deployment rate-limit accounting, so the
+ preflight helper must short-circuit instead of issuing a probe."""
+ from types import SimpleNamespace
+
+ from app.tasks.chat.stream_new_chat import _preflight_llm
+
+ fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None)
+ # Should return without raising or making any LiteLLM call.
+ await _preflight_llm(fake_llm)
+
+
+@pytest.mark.asyncio
+async def test_settle_speculative_agent_build_swallows_exceptions():
+ """``_settle_speculative_agent_build`` MUST always return cleanly so the
+ caller can safely re-touch the request-scoped session afterwards.
+
+ The helper guards the parallel preflight + agent-build path: when the
+ speculative build is being discarded (429 or non-429 preflight failure)
+ we await it solely to release any in-flight ``AsyncSession`` usage —
+ the build's outcome is irrelevant. Any exception (including
+ ``CancelledError``) leaking out would skip the caller's recovery flow
+ and re-introduce the very session-concurrency hazard the helper exists
+ to prevent.
+ """
+ import asyncio
+
+ from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
+
+ async def _raises() -> None:
+ raise RuntimeError("speculative build crashed")
+
+ async def _succeeds() -> str:
+ return "agent"
+
+ async def _slow() -> None:
+ await asyncio.sleep(0.05)
+
+ for coro in (_raises(), _succeeds(), _slow()):
+ task = asyncio.create_task(coro)
+ await _settle_speculative_agent_build(task)
+ assert task.done()
+
+
+@pytest.mark.asyncio
+async def test_settle_speculative_agent_build_handles_already_done_task():
+ """Done tasks (success or failure) must still be settled without raising."""
+ import asyncio
+
+ from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
+
+ async def _ok() -> str:
+ return "ok"
+
+ async def _bad() -> None:
+ raise ValueError("nope")
+
+ ok_task = asyncio.create_task(_ok())
+ bad_task = asyncio.create_task(_bad())
+ # Drive both to completion before settling.
+ await asyncio.sleep(0)
+ await asyncio.sleep(0)
+
+ await _settle_speculative_agent_build(ok_task)
+ await _settle_speculative_agent_build(bad_task)
+ assert ok_task.result() == "ok"
+ # ``bad_task`` exception was consumed by the settle helper; calling
+ # ``.exception()`` after the fact must still return the original error
+ # (the helper observes it but doesn't clear it).
+ assert isinstance(bad_task.exception(), ValueError)
+
+
+def test_stream_exception_classifies_thread_busy():
+ exc = BusyError(request_id="thread-123")
+ kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
+ exc, flow_label="chat"
+ )
+ assert kind == "thread_busy"
+ assert code == "THREAD_BUSY"
+ assert severity == "warn"
+ assert is_expected is True
+ assert "still finishing for this thread" in user_message
+ assert extra is None
+
+
+def test_stream_exception_classifies_thread_busy_from_message():
+ exc = Exception("Thread is busy with another request")
+ kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
+ exc, flow_label="chat"
+ )
+ assert kind == "thread_busy"
+ assert code == "THREAD_BUSY"
+ assert severity == "warn"
+ assert is_expected is True
+ assert "still finishing for this thread" in user_message
+ assert extra is None
+
+
+def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
+ thread_id = "thread-cancelling-1"
+ reset_cancel(thread_id)
+ request_cancel(thread_id)
+ exc = BusyError(request_id=thread_id)
+ kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
+ exc, flow_label="chat"
+ )
+ assert kind == "thread_busy"
+ assert code == "TURN_CANCELLING"
+ assert severity == "info"
+ assert is_expected is True
+ assert "stopping" in user_message
+ assert isinstance(extra, dict)
+ assert "retry_after_ms" in extra
+
+
+def test_premium_classification_is_error_code_driven():
+ classifier_path = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_web/lib/chat/chat-error-classifier.ts"
+ )
+ source = classifier_path.read_text(encoding="utf-8")
+
+ assert "PREMIUM_KEYWORDS" not in source
+ assert "RATE_LIMIT_KEYWORDS" not in source
+ assert "normalized.includes(" not in source
+ assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source
+
+
+def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook():
+ page_path = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
+ )
+ source = page_path.read_text(encoding="utf-8")
+
+ assert "onPreAcceptFailure?: () => Promise;" in source
+ assert "if (!accepted) {" in source
+ assert "await onPreAcceptFailure?.();" in source
+ assert "await onAcceptedStreamError?.();" in source
+ assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source
+ assert "setMessageDocumentsMap((prev) => {" in source
+
+
+def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
+ user_message_path = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_web/components/assistant-ui/user-message.tsx"
+ )
+ source = user_message_path.read_text(encoding="utf-8")
+
+ assert "Not sent. Edit and retry." not in source
+ assert "failed_pre_accept" not in source
+
+
+def test_network_send_failures_use_unified_retry_toast_message():
+ classifier_path = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_web/lib/chat/chat-error-classifier.ts"
+ )
+ classifier_source = classifier_path.read_text(encoding="utf-8")
+ request_errors_path = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_web/lib/chat/chat-request-errors.ts"
+ )
+ request_errors_source = request_errors_path.read_text(encoding="utf-8")
+
+ assert '"send_failed_pre_accept"' in classifier_source
+ assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
+ assert 'errorCode === "TURN_CANCELLING"' in classifier_source
+ assert "if (withCode.code) return withCode.code;" in classifier_source
+ assert 'userMessage: "Message not sent. Please retry."' in classifier_source
+ assert 'userMessage: "Connection issue. Please try again."' in classifier_source
+ assert "const passthroughCodes = new Set([" in request_errors_source
+ assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
+ assert '"THREAD_BUSY"' in request_errors_source
+ assert '"TURN_CANCELLING"' in request_errors_source
+ assert '"AUTH_EXPIRED"' in request_errors_source
+ assert '"UNAUTHORIZED"' in request_errors_source
+ assert '"RATE_LIMITED"' in request_errors_source
+ assert '"NETWORK_ERROR"' in request_errors_source
+ assert '"STREAM_PARSE_ERROR"' in request_errors_source
+ assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
+ assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
+ assert '"SERVER_ERROR"' in request_errors_source
+ assert "passthroughCodes.has(existingCode)" in request_errors_source
+ assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
+ assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
+ assert "Failed to start chat. Please try again." not in classifier_source
+
+
+def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
+ page_path = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
+ )
+ source = page_path.read_text(encoding="utf-8")
+
+ # Each flow tracks accepted boundary and passes it into shared terminal handling.
+ # The acceptance boundary is still meaningful post-refactor: it gates
+ # local-state cleanup (onPreAcceptFailure path) and lets the shared
+ # terminal handler distinguish pre-accept aborts from in-stream errors.
+ assert "let newAccepted = false;" in source
+ assert "let resumeAccepted = false;" in source
+ assert "let regenerateAccepted = false;" in source
+ assert "accepted: newAccepted," in source
+ assert "accepted: resumeAccepted," in source
+ assert "accepted: regenerateAccepted," in source
+
+ # NOTE: The FE-side persistence guards previously asserted here
+ # ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;",
+ # "if (newAccepted && !userPersisted) {") have been intentionally
+ # removed by the SSE-based message-id handshake refactor. Persistence
+ # is now server-authoritative: persist_user_turn / persist_assistant_shell
+ # run inside stream_new_chat / stream_resume_chat unconditionally and
+ # the FE consumes data-user-message-id / data-assistant-message-id
+ # SSE events to learn the canonical primary keys. There is therefore
+ # no FE call-site to guard, and the shared terminal handler relies
+ # purely on the `accepted` field above (forwarded to onAbort /
+ # onAcceptedStreamError) to drive UI cleanup. See
+ # tests/integration/chat/test_message_id_sse.py for the new
+ # cross-tier ID coherence guarantees.
+
+ # The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent
+ # of the persistence refactor and must still exist on every
+ # start-stream fetch.
+ assert "const fetchWithTurnCancellingRetry = useCallback(" in source
+ assert "computeFallbackTurnCancellingRetryDelay" in source
+ assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
+ assert 'withMeta.errorCode === "THREAD_BUSY"' in source
+ assert "await fetchWithTurnCancellingRetry(() =>" in source
+
+
+def test_cancel_active_turn_route_contract_exists():
+ routes_path = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_backend/app/routes/new_chat_routes.py"
+ )
+ source = routes_path.read_text(encoding="utf-8")
+
+ assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source
+ assert "response_model=CancelActiveTurnResponse" in source
+ assert 'status="cancelling",' in source
+ assert 'error_code="TURN_CANCELLING",' in source
+ assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
+ assert "retry_after_at=" in source
+ assert 'status="idle",' in source
+ assert 'error_code="NO_ACTIVE_TURN",' in source
+
+
+def test_turn_status_route_contract_exists():
+ routes_path = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_backend/app/routes/new_chat_routes.py"
+ )
+ source = routes_path.read_text(encoding="utf-8")
+
+ assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source
+ assert "response_model=TurnStatusResponse" in source
+ assert "_build_turn_status_payload(thread_id)" in source
+ assert "Permission.CHATS_READ.value" in source
+ assert "_raise_if_thread_busy_for_start(" in source
+
+
+def test_turn_cancelling_retry_policy_contract_exists():
+ routes_path = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_backend/app/routes/new_chat_routes.py"
+ )
+ source = routes_path.read_text(encoding="utf-8")
+
+ assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source
+ assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source
+ assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source
+ assert "def _compute_turn_cancelling_retry_delay(" in source
+ assert "retry-after-ms" in source
+ assert '"Retry-After"' in source
+ assert '"errorCode": "TURN_CANCELLING"' in source
+
+
+def test_turn_status_sse_contract_exists():
+ stream_source = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_backend/app/tasks/chat/stream_new_chat.py"
+ ).read_text(encoding="utf-8")
+ state_source = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_web/lib/chat/streaming-state.ts"
+ ).read_text(encoding="utf-8")
+ pipeline_source = (
+ Path(__file__).resolve().parents[3]
+ / "surfsense_web/lib/chat/stream-pipeline.ts"
+ ).read_text(encoding="utf-8")
+
+ assert '"turn-status"' in stream_source
+ assert '"status": "busy"' in stream_source
+ assert '"status": "idle"' in stream_source
+ assert 'type: "data-turn-status"' in state_source
+ assert 'case "data-turn-status":' in pipeline_source
+ assert "end_turn(str(chat_id))" in stream_source
+
+
+def test_chat_deepagent_forwards_resolved_model_name_to_both_builders():
+ """Regression guard: both system-prompt builders in chat_deepagent.py
+ must receive ``model_name=_resolve_prompt_model_name(...)`` so the
+    provider-variant dispatch can render the right provider-specific prompt
+    block. Without this the prompt silently falls back to the empty
+ ``"default"`` variant — the original bug being fixed.
+
+ This test mirrors :func:`test_stream_error_emission_keeps_machine_error_codes`
+ in style: it inspects module source text + a regex to enforce the
+ call-site shape, not just the wrapper layer (the wrappers already
+ forward ``model_name`` correctly, so testing them would not catch
+ the actual missed plumbing).
+ """
+ import app.agents.new_chat.chat_deepagent as chat_deepagent_module
+
+ source = inspect.getsource(chat_deepagent_module)
+
+ # Helper itself must be defined.
+ assert "def _resolve_prompt_model_name(" in source
+
+ # Both builder calls must forward the resolved model name. Match
+ # across newlines + whitespace because the kwargs are split over
+ # multiple lines.
+ pattern = re.compile(
+ r"build_(?:surfsense|configurable)_system_prompt\([^)]*"
+ r"model_name=_resolve_prompt_model_name\(",
+ re.DOTALL,
+ )
+ matches = pattern.findall(source)
+ assert len(matches) == 2, (
+ "Expected both system-prompt builder call sites to forward "
+ "`model_name=_resolve_prompt_model_name(...)`, found "
+ f"{len(matches)}"
+ )
diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock
index 209c42a9c..3e371cecc 100644
--- a/surfsense_backend/uv.lock
+++ b/surfsense_backend/uv.lock
@@ -62,7 +62,7 @@ wheels = [
[[package]]
name = "aiohttp"
-version = "3.13.5"
+version = "3.13.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohappyeyeballs" },
@@ -73,76 +73,76 @@ dependencies = [
{ name = "propcache" },
{ name = "yarl" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/77/9a/152096d4808df8e4268befa55fba462f440f14beab85e8ad9bf990516918/aiohttp-3.13.5.tar.gz", hash = "sha256:9d98cc980ecc96be6eb4c1994ce35d28d8b1f5e5208a23b421187d1209dbb7d1", size = 7858271 }
+sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/be/6f/353954c29e7dcce7cf00280a02c75f30e133c00793c7a2ed3776d7b2f426/aiohttp-3.13.5-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:023ecba036ddd840b0b19bf195bfae970083fd7024ce1ac22e9bba90464620e9", size = 748876 },
- { url = "https://files.pythonhosted.org/packages/f5/1b/428a7c64687b3b2e9cd293186695affc0e1e54a445d0361743b231f11066/aiohttp-3.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:15c933ad7920b7d9a20de151efcd05a6e38302cbf0e10c9b2acb9a42210a2416", size = 499557 },
- { url = "https://files.pythonhosted.org/packages/29/47/7be41556bfbb6917069d6a6634bb7dd5e163ba445b783a90d40f5ac7e3a7/aiohttp-3.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ab2899f9fa2f9f741896ebb6fa07c4c883bfa5c7f2ddd8cf2aafa86fa981b2d2", size = 500258 },
- { url = "https://files.pythonhosted.org/packages/67/84/c9ecc5828cb0b3695856c07c0a6817a99d51e2473400f705275a2b3d9239/aiohttp-3.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a60eaa2d440cd4707696b52e40ed3e2b0f73f65be07fd0ef23b6b539c9c0b0b4", size = 1749199 },
- { url = "https://files.pythonhosted.org/packages/f0/d3/3c6d610e66b495657622edb6ae7c7fd31b2e9086b4ec50b47897ad6042a9/aiohttp-3.13.5-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:55b3bdd3292283295774ab585160c4004f4f2f203946997f49aac032c84649e9", size = 1721013 },
- { url = "https://files.pythonhosted.org/packages/49/a0/24409c12217456df0bae7babe3b014e460b0b38a8e60753d6cb339f6556d/aiohttp-3.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2b2355dc094e5f7d45a7bb262fe7207aa0460b37a0d87027dcf21b5d890e7d5", size = 1781501 },
- { url = "https://files.pythonhosted.org/packages/98/9d/b65ec649adc5bccc008b0957a9a9c691070aeac4e41cea18559fef49958b/aiohttp-3.13.5-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b38765950832f7d728297689ad78f5f2cf79ff82487131c4d26fe6ceecdc5f8e", size = 1878981 },
- { url = "https://files.pythonhosted.org/packages/57/d8/8d44036d7eb7b6a8ec4c5494ea0c8c8b94fbc0ed3991c1a7adf230df03bf/aiohttp-3.13.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b18f31b80d5a33661e08c89e202edabf1986e9b49c42b4504371daeaa11b47c1", size = 1767934 },
- { url = "https://files.pythonhosted.org/packages/31/04/d3f8211f273356f158e3464e9e45484d3fb8c4ce5eb2f6fe9405c3273983/aiohttp-3.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:33add2463dde55c4f2d9635c6ab33ce154e5ecf322bd26d09af95c5f81cfa286", size = 1566671 },
- { url = "https://files.pythonhosted.org/packages/41/db/073e4ebe00b78e2dfcacff734291651729a62953b48933d765dc513bf798/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:327cc432fdf1356fb4fbc6fe833ad4e9f6aacb71a8acaa5f1855e4b25910e4a9", size = 1705219 },
- { url = "https://files.pythonhosted.org/packages/48/45/7dfba71a2f9fd97b15c95c06819de7eb38113d2cdb6319669195a7d64270/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7c35b0bf0b48a70b4cb4fc5d7bed9b932532728e124874355de1a0af8ec4bc88", size = 1743049 },
- { url = "https://files.pythonhosted.org/packages/18/71/901db0061e0f717d226386a7f471bb59b19566f2cae5f0d93874b017271f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:df23d57718f24badef8656c49743e11a89fd6f5358fa8a7b96e728fda2abf7d3", size = 1749557 },
- { url = "https://files.pythonhosted.org/packages/08/d5/41eebd16066e59cd43728fe74bce953d7402f2b4ddfdfef2c0e9f17ca274/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:02e048037a6501a5ec1f6fc9736135aec6eb8a004ce48838cb951c515f32c80b", size = 1558931 },
- { url = "https://files.pythonhosted.org/packages/30/e6/4a799798bf05740e66c3a1161079bda7a3dd8e22ca392481d7a7f9af82a6/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:31cebae8b26f8a615d2b546fee45d5ffb76852ae6450e2a03f42c9102260d6fe", size = 1774125 },
- { url = "https://files.pythonhosted.org/packages/84/63/7749337c90f92bc2cb18f9560d67aa6258c7060d1397d21529b8004fcf6f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:888e78eb5ca55a615d285c3c09a7a91b42e9dd6fc699b166ebd5dee87c9ccf14", size = 1732427 },
- { url = "https://files.pythonhosted.org/packages/98/de/cf2f44ff98d307e72fb97d5f5bbae3bfcb442f0ea9790c0bf5c5c2331404/aiohttp-3.13.5-cp312-cp312-win32.whl", hash = "sha256:8bd3ec6376e68a41f9f95f5ed170e2fcf22d4eb27a1f8cb361d0508f6e0557f3", size = 433534 },
- { url = "https://files.pythonhosted.org/packages/aa/ca/eadf6f9c8fa5e31d40993e3db153fb5ed0b11008ad5d9de98a95045bed84/aiohttp-3.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:110e448e02c729bcebb18c60b9214a87ba33bac4a9fa5e9a5f139938b56c6cb1", size = 460446 },
- { url = "https://files.pythonhosted.org/packages/78/e9/d76bf503005709e390122d34e15256b88f7008e246c4bdbe915cd4f1adce/aiohttp-3.13.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a5029cc80718bbd545123cd8fe5d15025eccaaaace5d0eeec6bd556ad6163d61", size = 742930 },
- { url = "https://files.pythonhosted.org/packages/57/00/4b7b70223deaebd9bb85984d01a764b0d7bd6526fcdc73cca83bcbe7243e/aiohttp-3.13.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4bb6bf5811620003614076bdc807ef3b5e38244f9d25ca5fe888eaccea2a9832", size = 496927 },
- { url = "https://files.pythonhosted.org/packages/9c/f5/0fb20fb49f8efdcdce6cd8127604ad2c503e754a8f139f5e02b01626523f/aiohttp-3.13.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a84792f8631bf5a94e52d9cc881c0b824ab42717165a5579c760b830d9392ac9", size = 497141 },
- { url = "https://files.pythonhosted.org/packages/3b/86/b7c870053e36a94e8951b803cb5b909bfbc9b90ca941527f5fcafbf6b0fa/aiohttp-3.13.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:57653eac22c6a4c13eb22ecf4d673d64a12f266e72785ab1c8b8e5940d0e8090", size = 1732476 },
- { url = "https://files.pythonhosted.org/packages/b5/e5/4e161f84f98d80c03a238671b4136e6530453d65262867d989bbe78244d0/aiohttp-3.13.5-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5e5f7debc7a57af53fdf5c5009f9391d9f4c12867049d509bf7bb164a6e295b", size = 1706507 },
- { url = "https://files.pythonhosted.org/packages/d4/56/ea11a9f01518bd5a2a2fcee869d248c4b8a0cfa0bb13401574fa31adf4d4/aiohttp-3.13.5-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c719f65bebcdf6716f10e9eff80d27567f7892d8988c06de12bbbd39307c6e3a", size = 1773465 },
- { url = "https://files.pythonhosted.org/packages/eb/40/333ca27fb74b0383f17c90570c748f7582501507307350a79d9f9f3c6eb1/aiohttp-3.13.5-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d97f93fdae594d886c5a866636397e2bcab146fd7a132fd6bb9ce182224452f8", size = 1873523 },
- { url = "https://files.pythonhosted.org/packages/f0/d2/e2f77eef1acb7111405433c707dc735e63f67a56e176e72e9e7a2cd3f493/aiohttp-3.13.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3df334e39d4c2f899a914f1dba283c1aadc311790733f705182998c6f7cae665", size = 1754113 },
- { url = "https://files.pythonhosted.org/packages/fb/56/3f653d7f53c89669301ec9e42c95233e2a0c0a6dd051269e6e678db4fdb0/aiohttp-3.13.5-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fe6970addfea9e5e081401bcbadf865d2b6da045472f58af08427e108d618540", size = 1562351 },
- { url = "https://files.pythonhosted.org/packages/ec/a6/9b3e91eb8ae791cce4ee736da02211c85c6f835f1bdfac0594a8a3b7018c/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7becdf835feff2f4f335d7477f121af787e3504b48b449ff737afb35869ba7bb", size = 1693205 },
- { url = "https://files.pythonhosted.org/packages/98/fc/bfb437a99a2fcebd6b6eaec609571954de2ed424f01c352f4b5504371dd3/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:676e5651705ad5d8a70aeb8eb6936c436d8ebbd56e63436cb7dd9bb36d2a9a46", size = 1730618 },
- { url = "https://files.pythonhosted.org/packages/e4/b6/c8534862126191a034f68153194c389addc285a0f1347d85096d349bbc15/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:9b16c653d38eb1a611cc898c41e76859ca27f119d25b53c12875fd0474ae31a8", size = 1745185 },
- { url = "https://files.pythonhosted.org/packages/0b/93/4ca8ee2ef5236e2707e0fd5fecb10ce214aee1ff4ab307af9c558bda3b37/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:999802d5fa0389f58decd24b537c54aa63c01c3219ce17d1214cbda3c2b22d2d", size = 1557311 },
- { url = "https://files.pythonhosted.org/packages/57/ae/76177b15f18c5f5d094f19901d284025db28eccc5ae374d1d254181d33f4/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:ec707059ee75732b1ba130ed5f9580fe10ff75180c812bc267ded039db5128c6", size = 1773147 },
- { url = "https://files.pythonhosted.org/packages/01/a4/62f05a0a98d88af59d93b7fcac564e5f18f513cb7471696ac286db970d6a/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2d6d44a5b48132053c2f6cd5c8cb14bc67e99a63594e336b0f2af81e94d5530c", size = 1730356 },
- { url = "https://files.pythonhosted.org/packages/e4/85/fc8601f59dfa8c9523808281f2da571f8b4699685f9809a228adcc90838d/aiohttp-3.13.5-cp313-cp313-win32.whl", hash = "sha256:329f292ed14d38a6c4c435e465f48bebb47479fd676a0411936cc371643225cc", size = 432637 },
- { url = "https://files.pythonhosted.org/packages/c0/1b/ac685a8882896acf0f6b31d689e3792199cfe7aba37969fa91da63a7fa27/aiohttp-3.13.5-cp313-cp313-win_amd64.whl", hash = "sha256:69f571de7500e0557801c0b51f4780482c0ec5fe2ac851af5a92cfce1af1cb83", size = 458896 },
- { url = "https://files.pythonhosted.org/packages/5d/ce/46572759afc859e867a5bc8ec3487315869013f59281ce61764f76d879de/aiohttp-3.13.5-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:eb4639f32fd4a9904ab8fb45bf3383ba71137f3d9d4ba25b3b3f3109977c5b8c", size = 745721 },
- { url = "https://files.pythonhosted.org/packages/13/fe/8a2efd7626dbe6049b2ef8ace18ffda8a4dfcbe1bcff3ac30c0c7575c20b/aiohttp-3.13.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:7e5dc4311bd5ac493886c63cbf76ab579dbe4641268e7c74e48e774c74b6f2be", size = 497663 },
- { url = "https://files.pythonhosted.org/packages/9b/91/cc8cc78a111826c54743d88651e1687008133c37e5ee615fee9b57990fac/aiohttp-3.13.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:756c3c304d394977519824449600adaf2be0ccee76d206ee339c5e76b70ded25", size = 499094 },
- { url = "https://files.pythonhosted.org/packages/0a/33/a8362cb15cf16a3af7e86ed11962d5cd7d59b449202dc576cdc731310bde/aiohttp-3.13.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecc26751323224cf8186efcf7fbcbc30f4e1d8c7970659daf25ad995e4032a56", size = 1726701 },
- { url = "https://files.pythonhosted.org/packages/45/0c/c091ac5c3a17114bd76cbf85d674650969ddf93387876cf67f754204bd77/aiohttp-3.13.5-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10a75acfcf794edf9d8db50e5a7ec5fc818b2a8d3f591ce93bc7b1210df016d2", size = 1683360 },
- { url = "https://files.pythonhosted.org/packages/23/73/bcee1c2b79bc275e964d1446c55c54441a461938e70267c86afaae6fba27/aiohttp-3.13.5-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0f7a18f258d124cd678c5fe072fe4432a4d5232b0657fca7c1847f599233c83a", size = 1773023 },
- { url = "https://files.pythonhosted.org/packages/c7/ef/720e639df03004fee2d869f771799d8c23046dec47d5b81e396c7cda583a/aiohttp-3.13.5-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:df6104c009713d3a89621096f3e3e88cc323fd269dbd7c20afe18535094320be", size = 1853795 },
- { url = "https://files.pythonhosted.org/packages/bd/c9/989f4034fb46841208de7aeeac2c6d8300745ab4f28c42f629ba77c2d916/aiohttp-3.13.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:241a94f7de7c0c3b616627aaad530fe2cb620084a8b144d3be7b6ecfe95bae3b", size = 1730405 },
- { url = "https://files.pythonhosted.org/packages/ce/75/ee1fd286ca7dc599d824b5651dad7b3be7ff8d9a7e7b3fe9820d9180f7db/aiohttp-3.13.5-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c974fb66180e58709b6fc402846f13791240d180b74de81d23913abe48e96d94", size = 1558082 },
- { url = "https://files.pythonhosted.org/packages/c3/20/1e9e6650dfc436340116b7aa89ff8cb2bbdf0abc11dfaceaad8f74273a10/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:6e27ea05d184afac78aabbac667450c75e54e35f62238d44463131bd3f96753d", size = 1692346 },
- { url = "https://files.pythonhosted.org/packages/d8/40/8ebc6658d48ea630ac7903912fe0dd4e262f0e16825aa4c833c56c9f1f56/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:a79a6d399cef33a11b6f004c67bb07741d91f2be01b8d712d52c75711b1e07c7", size = 1698891 },
- { url = "https://files.pythonhosted.org/packages/d8/78/ea0ae5ec8ba7a5c10bdd6e318f1ba5e76fcde17db8275188772afc7917a4/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:c632ce9c0b534fbe25b52c974515ed674937c5b99f549a92127c85f771a78772", size = 1742113 },
- { url = "https://files.pythonhosted.org/packages/8a/66/9d308ed71e3f2491be1acb8769d96c6f0c47d92099f3bc9119cada27b357/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:fceedde51fbd67ee2bcc8c0b33d0126cc8b51ef3bbde2f86662bd6d5a6f10ec5", size = 1553088 },
- { url = "https://files.pythonhosted.org/packages/da/a6/6cc25ed8dfc6e00c90f5c6d126a98e2cf28957ad06fa1036bd34b6f24a2c/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:f92995dfec9420bb69ae629abf422e516923ba79ba4403bc750d94fb4a6c68c1", size = 1757976 },
- { url = "https://files.pythonhosted.org/packages/c1/2b/cce5b0ffe0de99c83e5e36d8f828e4161e415660a9f3e58339d07cce3006/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20ae0ff08b1f2c8788d6fb85afcb798654ae6ba0b747575f8562de738078457b", size = 1712444 },
- { url = "https://files.pythonhosted.org/packages/6c/cf/9e1795b4160c58d29421eafd1a69c6ce351e2f7c8d3c6b7e4ca44aea1a5b/aiohttp-3.13.5-cp314-cp314-win32.whl", hash = "sha256:b20df693de16f42b2472a9c485e1c948ee55524786a0a34345511afdd22246f3", size = 438128 },
- { url = "https://files.pythonhosted.org/packages/22/4d/eaedff67fc805aeba4ba746aec891b4b24cebb1a7d078084b6300f79d063/aiohttp-3.13.5-cp314-cp314-win_amd64.whl", hash = "sha256:f85c6f327bf0b8c29da7d93b1cabb6363fb5e4e160a32fa241ed2dce21b73162", size = 464029 },
- { url = "https://files.pythonhosted.org/packages/79/11/c27d9332ee20d68dd164dc12a6ecdef2e2e35ecc97ed6cf0d2442844624b/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:1efb06900858bb618ff5cee184ae2de5828896c448403d51fb633f09e109be0a", size = 778758 },
- { url = "https://files.pythonhosted.org/packages/04/fb/377aead2e0a3ba5f09b7624f702a964bdf4f08b5b6728a9799830c80041e/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:fee86b7c4bd29bdaf0d53d14739b08a106fdda809ca5fe032a15f52fae5fe254", size = 512883 },
- { url = "https://files.pythonhosted.org/packages/bb/a6/aa109a33671f7a5d3bd78b46da9d852797c5e665bfda7d6b373f56bff2ec/aiohttp-3.13.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:20058e23909b9e65f9da62b396b77dfa95965cbe840f8def6e572538b1d32e36", size = 516668 },
- { url = "https://files.pythonhosted.org/packages/79/b3/ca078f9f2fa9563c36fb8ef89053ea2bb146d6f792c5104574d49d8acb63/aiohttp-3.13.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cf20a8d6868cb15a73cab329ffc07291ba8c22b1b88176026106ae39aa6df0f", size = 1883461 },
- { url = "https://files.pythonhosted.org/packages/b7/e3/a7ad633ca1ca497b852233a3cce6906a56c3225fb6d9217b5e5e60b7419d/aiohttp-3.13.5-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:330f5da04c987f1d5bdb8ae189137c77139f36bd1cb23779ca1a354a4b027800", size = 1747661 },
- { url = "https://files.pythonhosted.org/packages/33/b9/cd6fe579bed34a906d3d783fe60f2fa297ef55b27bb4538438ee49d4dc41/aiohttp-3.13.5-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6f1cbf0c7926d315c3c26c2da41fd2b5d2fe01ac0e157b78caefc51a782196cf", size = 1863800 },
- { url = "https://files.pythonhosted.org/packages/c0/3f/2c1e2f5144cefa889c8afd5cf431994c32f3b29da9961698ff4e3811b79a/aiohttp-3.13.5-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:53fc049ed6390d05423ba33103ded7281fe897cf97878f369a527070bd95795b", size = 1958382 },
- { url = "https://files.pythonhosted.org/packages/66/1d/f31ec3f1013723b3babe3609e7f119c2c2fb6ef33da90061a705ef3e1bc8/aiohttp-3.13.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:898703aa2667e3c5ca4c54ca36cd73f58b7a38ef87a5606414799ebce4d3fd3a", size = 1803724 },
- { url = "https://files.pythonhosted.org/packages/0e/b4/57712dfc6f1542f067daa81eb61da282fab3e6f1966fca25db06c4fc62d5/aiohttp-3.13.5-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0494a01ca9584eea1e5fbd6d748e61ecff218c51b576ee1999c23db7066417d8", size = 1640027 },
- { url = "https://files.pythonhosted.org/packages/25/3c/734c878fb43ec083d8e31bf029daae1beafeae582d1b35da234739e82ee7/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:6cf81fe010b8c17b09495cbd15c1d35afbc8fb405c0c9cf4738e5ae3af1d65be", size = 1806644 },
- { url = "https://files.pythonhosted.org/packages/20/a5/f671e5cbec1c21d044ff3078223f949748f3a7f86b14e34a365d74a5d21f/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:c564dd5f09ddc9d8f2c2d0a301cd30a79a2cc1b46dd1a73bef8f0038863d016b", size = 1791630 },
- { url = "https://files.pythonhosted.org/packages/0b/63/fb8d0ad63a0b8a99be97deac8c04dacf0785721c158bdf23d679a87aa99e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:2994be9f6e51046c4f864598fd9abeb4fba6e88f0b2152422c9666dcd4aea9c6", size = 1809403 },
- { url = "https://files.pythonhosted.org/packages/59/0c/bfed7f30662fcf12206481c2aac57dedee43fe1c49275e85b3a1e1742294/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:157826e2fa245d2ef46c83ea8a5faf77ca19355d278d425c29fda0beb3318037", size = 1634924 },
- { url = "https://files.pythonhosted.org/packages/17/d6/fd518d668a09fd5a3319ae5e984d4d80b9a4b3df4e21c52f02251ef5a32e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:a8aca50daa9493e9e13c0f566201a9006f080e7c50e5e90d0b06f53146a54500", size = 1836119 },
- { url = "https://files.pythonhosted.org/packages/78/b7/15fb7a9d52e112a25b621c67b69c167805cb1f2ab8f1708a5c490d1b52fe/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3b13560160d07e047a93f23aaa30718606493036253d5430887514715b67c9d9", size = 1772072 },
- { url = "https://files.pythonhosted.org/packages/7e/df/57ba7f0c4a553fc2bd8b6321df236870ec6fd64a2a473a8a13d4f733214e/aiohttp-3.13.5-cp314-cp314t-win32.whl", hash = "sha256:9a0f4474b6ea6818b41f82172d799e4b3d29e22c2c520ce4357856fced9af2f8", size = 471819 },
- { url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441 },
+ { url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158 },
+ { url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037 },
+ { url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556 },
+ { url = "https://files.pythonhosted.org/packages/d6/10/88ff67cd48a6ec36335b63a640abe86135791544863e0cfe1f065d6cef7a/aiohttp-3.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d738ebab9f71ee652d9dbd0211057690022201b11197f9a7324fd4dba128aa97", size = 1757314 },
+ { url = "https://files.pythonhosted.org/packages/8b/15/fdb90a5cf5a1f52845c276e76298c75fbbcc0ac2b4a86551906d54529965/aiohttp-3.13.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0ce692c3468fa831af7dceed52edf51ac348cebfc8d3feb935927b63bd3e8576", size = 1731819 },
+ { url = "https://files.pythonhosted.org/packages/ec/df/28146785a007f7820416be05d4f28cc207493efd1e8c6c1068e9bdc29198/aiohttp-3.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8e08abcfe752a454d2cb89ff0c08f2d1ecd057ae3e8cc6d84638de853530ebab", size = 1793279 },
+ { url = "https://files.pythonhosted.org/packages/10/47/689c743abf62ea7a77774d5722f220e2c912a77d65d368b884d9779ef41b/aiohttp-3.13.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5977f701b3fff36367a11087f30ea73c212e686d41cd363c50c022d48b011d8d", size = 1891082 },
+ { url = "https://files.pythonhosted.org/packages/b0/b6/f7f4f318c7e58c23b761c9b13b9a3c9b394e0f9d5d76fbc6622fa98509f6/aiohttp-3.13.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:54203e10405c06f8b6020bd1e076ae0fe6c194adcee12a5a78af3ffa3c57025e", size = 1773938 },
+ { url = "https://files.pythonhosted.org/packages/aa/06/f207cb3121852c989586a6fc16ff854c4fcc8651b86c5d3bd1fc83057650/aiohttp-3.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:358a6af0145bc4dda037f13167bef3cce54b132087acc4c295c739d05d16b1c3", size = 1579548 },
+ { url = "https://files.pythonhosted.org/packages/6c/58/e1289661a32161e24c1fe479711d783067210d266842523752869cc1d9c2/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:898ea1850656d7d61832ef06aa9846ab3ddb1621b74f46de78fbc5e1a586ba83", size = 1714669 },
+ { url = "https://files.pythonhosted.org/packages/96/0a/3e86d039438a74a86e6a948a9119b22540bae037d6ba317a042ae3c22711/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7bc30cceb710cf6a44e9617e43eebb6e3e43ad855a34da7b4b6a73537d8a6763", size = 1754175 },
+ { url = "https://files.pythonhosted.org/packages/f4/30/e717fc5df83133ba467a560b6d8ef20197037b4bb5d7075b90037de1018e/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4a31c0c587a8a038f19a4c7e60654a6c899c9de9174593a13e7cc6e15ff271f9", size = 1762049 },
+ { url = "https://files.pythonhosted.org/packages/e4/28/8f7a2d4492e336e40005151bdd94baf344880a4707573378579f833a64c1/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:2062f675f3fe6e06d6113eb74a157fb9df58953ffed0cdb4182554b116545758", size = 1570861 },
+ { url = "https://files.pythonhosted.org/packages/78/45/12e1a3d0645968b1c38de4b23fdf270b8637735ea057d4f84482ff918ad9/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d1ba8afb847ff80626d5e408c1fdc99f942acc877d0702fe137015903a220a9", size = 1790003 },
+ { url = "https://files.pythonhosted.org/packages/eb/0f/60374e18d590de16dcb39d6ff62f39c096c1b958e6f37727b5870026ea30/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b08149419994cdd4d5eecf7fd4bc5986b5a9380285bcd01ab4c0d6bfca47b79d", size = 1737289 },
+ { url = "https://files.pythonhosted.org/packages/02/bf/535e58d886cfbc40a8b0013c974afad24ef7632d645bca0b678b70033a60/aiohttp-3.13.4-cp312-cp312-win32.whl", hash = "sha256:fc432f6a2c4f720180959bc19aa37259651c1a4ed8af8afc84dd41c60f15f791", size = 434185 },
+ { url = "https://files.pythonhosted.org/packages/1e/1a/d92e3325134ebfff6f4069f270d3aac770d63320bd1fcd0eca023e74d9a8/aiohttp-3.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:6148c9ae97a3e8bff9a1fc9c757fa164116f86c100468339730e717590a3fb77", size = 461285 },
+ { url = "https://files.pythonhosted.org/packages/e3/ac/892f4162df9b115b4758d615f32ec63d00f3084c705ff5526630887b9b42/aiohttp-3.13.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:63dd5e5b1e43b8fb1e91b79b7ceba1feba588b317d1edff385084fcc7a0a4538", size = 745744 },
+ { url = "https://files.pythonhosted.org/packages/97/a9/c5b87e4443a2f0ea88cb3000c93a8fdad1ee63bffc9ded8d8c8e0d66efc6/aiohttp-3.13.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:746ac3cc00b5baea424dacddea3ec2c2702f9590de27d837aa67004db1eebc6e", size = 498178 },
+ { url = "https://files.pythonhosted.org/packages/94/42/07e1b543a61250783650df13da8ddcdc0d0a5538b2bd15cef6e042aefc61/aiohttp-3.13.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bda8f16ea99d6a6705e5946732e48487a448be874e54a4f73d514660ff7c05d3", size = 498331 },
+ { url = "https://files.pythonhosted.org/packages/20/d6/492f46bf0328534124772d0cf58570acae5b286ea25006900650f69dae0e/aiohttp-3.13.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4b061e7b5f840391e3f64d0ddf672973e45c4cfff7a0feea425ea24e51530fc2", size = 1744414 },
+ { url = "https://files.pythonhosted.org/packages/e2/4d/e02627b2683f68051246215d2d62b2d2f249ff7a285e7a858dc47d6b6a14/aiohttp-3.13.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b252e8d5cd66184b570d0d010de742736e8a4fab22c58299772b0c5a466d4b21", size = 1719226 },
+ { url = "https://files.pythonhosted.org/packages/7b/6c/5d0a3394dd2b9f9aeba6e1b6065d0439e4b75d41f1fb09a3ec010b43552b/aiohttp-3.13.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:20af8aad61d1803ff11152a26146d8d81c266aa8c5aa9b4504432abb965c36a0", size = 1782110 },
+ { url = "https://files.pythonhosted.org/packages/0d/2d/c20791e3437700a7441a7edfb59731150322424f5aadf635602d1d326101/aiohttp-3.13.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:13a5cc924b59859ad2adb1478e31f410a7ed46e92a2a619d6d1dd1a63c1a855e", size = 1884809 },
+ { url = "https://files.pythonhosted.org/packages/c8/94/d99dbfbd1924a87ef643833932eb2a3d9e5eee87656efea7d78058539eff/aiohttp-3.13.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:534913dfb0a644d537aebb4123e7d466d94e3be5549205e6a31f72368980a81a", size = 1764938 },
+ { url = "https://files.pythonhosted.org/packages/49/61/3ce326a1538781deb89f6cf5e094e2029cd308ed1e21b2ba2278b08426f6/aiohttp-3.13.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:320e40192a2dcc1cf4b5576936e9652981ab596bf81eb309535db7e2f5b5672f", size = 1570697 },
+ { url = "https://files.pythonhosted.org/packages/b6/77/4ab5a546857bb3028fbaf34d6eea180267bdab022ee8b1168b1fcde4bfdd/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9e587fcfce2bcf06526a43cb705bdee21ac089096f2e271d75de9c339db3100c", size = 1702258 },
+ { url = "https://files.pythonhosted.org/packages/79/63/d8f29021e39bc5af8e5d5e9da1b07976fb9846487a784e11e4f4eeda4666/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:9eb9c2eea7278206b5c6c1441fdd9dc420c278ead3f3b2cc87f9b693698cc500", size = 1740287 },
+ { url = "https://files.pythonhosted.org/packages/55/3a/cbc6b3b124859a11bc8055d3682c26999b393531ef926754a3445b99dfef/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:29be00c51972b04bf9d5c8f2d7f7314f48f96070ca40a873a53056e652e805f7", size = 1753011 },
+ { url = "https://files.pythonhosted.org/packages/e0/30/836278675205d58c1368b21520eab9572457cf19afd23759216c04483048/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:90c06228a6c3a7c9f776fe4fc0b7ff647fffd3bed93779a6913c804ae00c1073", size = 1566359 },
+ { url = "https://files.pythonhosted.org/packages/50/b4/8032cc9b82d17e4277704ba30509eaccb39329dc18d6a35f05e424439e32/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:a533ec132f05fd9a1d959e7f34184cd7d5e8511584848dab85faefbaac573069", size = 1785537 },
+ { url = "https://files.pythonhosted.org/packages/17/7d/5873e98230bde59f493bf1f7c3e327486a4b5653fa401144704df5d00211/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1c946f10f413836f82ea4cfb90200d2a59578c549f00857e03111cf45ad01ca5", size = 1740752 },
+ { url = "https://files.pythonhosted.org/packages/7b/f2/13e46e0df051494d7d3c68b7f72d071f48c384c12716fc294f75d5b1a064/aiohttp-3.13.4-cp313-cp313-win32.whl", hash = "sha256:48708e2706106da6967eff5908c78ca3943f005ed6bcb75da2a7e4da94ef8c70", size = 433187 },
+ { url = "https://files.pythonhosted.org/packages/ea/c0/649856ee655a843c8f8664592cfccb73ac80ede6a8c8db33a25d810c12db/aiohttp-3.13.4-cp313-cp313-win_amd64.whl", hash = "sha256:74a2eb058da44fa3a877a49e2095b591d4913308bb424c418b77beb160c55ce3", size = 459778 },
+ { url = "https://files.pythonhosted.org/packages/6d/29/6657cc37ae04cacc2dbf53fb730a06b6091cc4cbe745028e047c53e6d840/aiohttp-3.13.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:e0a2c961fc92abeff61d6444f2ce6ad35bb982db9fc8ff8a47455beacf454a57", size = 749363 },
+ { url = "https://files.pythonhosted.org/packages/90/7f/30ccdf67ca3d24b610067dc63d64dcb91e5d88e27667811640644aa4a85d/aiohttp-3.13.4-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:153274535985a0ff2bff1fb6c104ed547cec898a09213d21b0f791a44b14d933", size = 499317 },
+ { url = "https://files.pythonhosted.org/packages/93/13/e372dd4e68ad04ee25dafb050c7f98b0d91ea643f7352757e87231102555/aiohttp-3.13.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:351f3171e2458da3d731ce83f9e6b9619e325c45cbd534c7759750cabf453ad7", size = 500477 },
+ { url = "https://files.pythonhosted.org/packages/e5/fe/ee6298e8e586096fb6f5eddd31393d8544f33ae0792c71ecbb4c2bef98ac/aiohttp-3.13.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f989ac8bc5595ff761a5ccd32bdb0768a117f36dd1504b1c2c074ed5d3f4df9c", size = 1737227 },
+ { url = "https://files.pythonhosted.org/packages/b0/b9/a7a0463a09e1a3fe35100f74324f23644bfc3383ac5fd5effe0722a5f0b7/aiohttp-3.13.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d36fc1709110ec1e87a229b201dd3ddc32aa01e98e7868083a794609b081c349", size = 1694036 },
+ { url = "https://files.pythonhosted.org/packages/57/7c/8972ae3fb7be00a91aee6b644b2a6a909aedb2c425269a3bfd90115e6f8f/aiohttp-3.13.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:42adaeea83cbdf069ab94f5103ce0787c21fb1a0153270da76b59d5578302329", size = 1786814 },
+ { url = "https://files.pythonhosted.org/packages/93/01/c81e97e85c774decbaf0d577de7d848934e8166a3a14ad9f8aa5be329d28/aiohttp-3.13.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:92deb95469928cc41fd4b42a95d8012fa6df93f6b1c0a83af0ffbc4a5e218cde", size = 1866676 },
+ { url = "https://files.pythonhosted.org/packages/5a/5f/5b46fe8694a639ddea2cd035bf5729e4677ea882cb251396637e2ef1590d/aiohttp-3.13.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0c0c7c07c4257ef3a1df355f840bc62d133bcdef5c1c5ba75add3c08553e2eed", size = 1740842 },
+ { url = "https://files.pythonhosted.org/packages/20/a2/0d4b03d011cca6b6b0acba8433193c1e484efa8d705ea58295590fe24203/aiohttp-3.13.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f062c45de8a1098cb137a1898819796a2491aec4e637a06b03f149315dff4d8f", size = 1566508 },
+ { url = "https://files.pythonhosted.org/packages/98/17/e689fd500da52488ec5f889effd6404dece6a59de301e380f3c64f167beb/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:76093107c531517001114f0ebdb4f46858ce818590363e3e99a4a2280334454a", size = 1700569 },
+ { url = "https://files.pythonhosted.org/packages/d8/0d/66402894dbcf470ef7db99449e436105ea862c24f7ea4c95c683e635af35/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:6f6ec32162d293b82f8b63a16edc80769662fbd5ae6fbd4936d3206a2c2cc63b", size = 1707407 },
+ { url = "https://files.pythonhosted.org/packages/2f/eb/af0ab1a3650092cbd8e14ef29e4ab0209e1460e1c299996c3f8288b3f1ff/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:5903e2db3d202a00ad9f0ec35a122c005e85d90c9836ab4cda628f01edf425e2", size = 1752214 },
+ { url = "https://files.pythonhosted.org/packages/5a/bf/72326f8a98e4c666f292f03c385545963cc65e358835d2a7375037a97b57/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2d5bea57be7aca98dbbac8da046d99b5557c5cf4e28538c4c786313078aca09e", size = 1562162 },
+ { url = "https://files.pythonhosted.org/packages/67/9f/13b72435f99151dd9a5469c96b3b5f86aa29b7e785ca7f35cf5e538f74c0/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:bcf0c9902085976edc0232b75006ef38f89686901249ce14226b6877f88464fb", size = 1768904 },
+ { url = "https://files.pythonhosted.org/packages/18/bc/28d4970e7d5452ac7776cdb5431a1164a0d9cf8bd2fffd67b4fb463aa56d/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c3295f98bfeed2e867cab588f2a146a9db37a85e3ae9062abf46ba062bd29165", size = 1723378 },
+ { url = "https://files.pythonhosted.org/packages/53/74/b32458ca1a7f34d65bdee7aef2036adbe0438123d3d53e2b083c453c24dd/aiohttp-3.13.4-cp314-cp314-win32.whl", hash = "sha256:a598a5c5767e1369d8f5b08695cab1d8160040f796c4416af76fd773d229b3c9", size = 438711 },
+ { url = "https://files.pythonhosted.org/packages/40/b2/54b487316c2df3e03a8f3435e9636f8a81a42a69d942164830d193beb56a/aiohttp-3.13.4-cp314-cp314-win_amd64.whl", hash = "sha256:c555db4bc7a264bead5a7d63d92d41a1122fcd39cc62a4db815f45ad46f9c2c8", size = 464977 },
+ { url = "https://files.pythonhosted.org/packages/47/fb/e41b63c6ce71b07a59243bb8f3b457ee0c3402a619acb9d2c0d21ef0e647/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:45abbbf09a129825d13c18c7d3182fecd46d9da3cfc383756145394013604ac1", size = 781549 },
+ { url = "https://files.pythonhosted.org/packages/97/53/532b8d28df1e17e44c4d9a9368b78dcb6bf0b51037522136eced13afa9e8/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:74c80b2bc2c2adb7b3d1941b2b60701ee2af8296fc8aad8b8bc48bc25767266c", size = 514383 },
+ { url = "https://files.pythonhosted.org/packages/1b/1f/62e5d400603e8468cd635812d99cb81cfdc08127a3dc474c647615f31339/aiohttp-3.13.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c97989ae40a9746650fa196894f317dafc12227c808c774929dda0ff873a5954", size = 518304 },
+ { url = "https://files.pythonhosted.org/packages/90/57/2326b37b10896447e3c6e0cbef4fe2486d30913639a5cfd1332b5d870f82/aiohttp-3.13.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dae86be9811493f9990ef44fff1685f5c1a3192e9061a71a109d527944eed551", size = 1893433 },
+ { url = "https://files.pythonhosted.org/packages/d2/b4/a24d82112c304afdb650167ef2fe190957d81cbddac7460bedd245f765aa/aiohttp-3.13.4-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:1db491abe852ca2fa6cc48a3341985b0174b3741838e1341b82ac82c8bd9e871", size = 1755901 },
+ { url = "https://files.pythonhosted.org/packages/9e/2d/0883ef9d878d7846287f036c162a951968f22aabeef3ac97b0bea6f76d5d/aiohttp-3.13.4-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0e5d701c0aad02a7dce72eef6b93226cf3734330f1a31d69ebbf69f33b86666e", size = 1876093 },
+ { url = "https://files.pythonhosted.org/packages/ad/52/9204bb59c014869b71971addad6778f005daa72a96eed652c496789d7468/aiohttp-3.13.4-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8ac32a189081ae0a10ba18993f10f338ec94341f0d5df8fff348043962f3c6f8", size = 1970815 },
+ { url = "https://files.pythonhosted.org/packages/d6/b5/e4eb20275a866dde0f570f411b36c6b48f7b53edfe4f4071aa1b0728098a/aiohttp-3.13.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:98e968cdaba43e45c73c3f306fca418c8009a957733bac85937c9f9cf3f4de27", size = 1816223 },
+ { url = "https://files.pythonhosted.org/packages/d8/23/e98075c5bb146aa61a1239ee1ac7714c85e814838d6cebbe37d3fe19214a/aiohttp-3.13.4-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca114790c9144c335d538852612d3e43ea0f075288f4849cf4b05d6cd2238ce7", size = 1649145 },
+ { url = "https://files.pythonhosted.org/packages/d6/c1/7bad8be33bb06c2bb224b6468874346026092762cbec388c3bdb65a368ee/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ea2e071661ba9cfe11eabbc81ac5376eaeb3061f6e72ec4cc86d7cdd1ffbdbbb", size = 1816562 },
+ { url = "https://files.pythonhosted.org/packages/5c/10/c00323348695e9a5e316825969c88463dcc24c7e9d443244b8a2c9cf2eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:34e89912b6c20e0fd80e07fa401fd218a410aa1ce9f1c2f1dad6db1bd0ce0927", size = 1800333 },
+ { url = "https://files.pythonhosted.org/packages/84/43/9b2147a1df3559f49bd723e22905b46a46c068a53adb54abdca32c4de180/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:0e217cf9f6a42908c52b46e42c568bd57adc39c9286ced31aaace614b6087965", size = 1820617 },
+ { url = "https://files.pythonhosted.org/packages/a9/7f/b3481a81e7a586d02e99387b18c6dafff41285f6efd3daa2124c01f87eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:0c296f1221e21ba979f5ac1964c3b78cfde15c5c5f855ffd2caab337e9cd9182", size = 1643417 },
+ { url = "https://files.pythonhosted.org/packages/8f/72/07181226bc99ce1124e0f89280f5221a82d3ae6a6d9d1973ce429d48e52b/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:d99a9d168ebaffb74f36d011750e490085ac418f4db926cce3989c8fe6cb6b1b", size = 1849286 },
+ { url = "https://files.pythonhosted.org/packages/1a/e6/1b3566e103eca6da5be4ae6713e112a053725c584e96574caf117568ffef/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cb19177205d93b881f3f89e6081593676043a6828f59c78c17a0fd6c1fbed2ba", size = 1782635 },
+ { url = "https://files.pythonhosted.org/packages/37/58/1b11c71904b8d079eb0c39fe664180dd1e14bebe5608e235d8bfbadc8929/aiohttp-3.13.4-cp314-cp314t-win32.whl", hash = "sha256:c606aa5656dab6552e52ca368e43869c916338346bfaf6304e15c58fb113ea30", size = 472537 },
+ { url = "https://files.pythonhosted.org/packages/bc/8f/87c56a1a1977d7dddea5b31e12189665a140fdb48a71e9038ff90bb564ec/aiohttp-3.13.4-cp314-cp314t-win_amd64.whl", hash = "sha256:014dcc10ec8ab8db681f0d68e939d1e9286a5aa2b993cbbdb0db130853e02144", size = 506381 },
]
[[package]]
@@ -3723,7 +3723,7 @@ wheels = [
[[package]]
name = "litellm"
-version = "1.83.4"
+version = "1.83.14"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
@@ -3739,9 +3739,9 @@ dependencies = [
{ name = "tiktoken" },
{ name = "tokenizers" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/03/c4/30469c06ae7437a4406bc11e3c433cfd380a6771068cca15ea918dcd158f/litellm-1.83.4.tar.gz", hash = "sha256:6458d2030a41229460b321adee00517a91dbd8e63213cc953d355cb41d16f2d4", size = 17733899 }
+sdist = { url = "https://files.pythonhosted.org/packages/8d/7c/c095649380adc96c8630273c1768c2ad1e74aa2ee1dd8dd05d218a60569f/litellm-1.83.14.tar.gz", hash = "sha256:24aef9b47cdc424c833e32f3727f411741c690832cd1fe4405e0077144fe09c9", size = 14836599 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/b8/bd/df19d3f8f6654535ee343a341fd921f81c411abf601a53e3eaef58129b02/litellm-1.83.4-py3-none-any.whl", hash = "sha256:17d7b4d48d47aca988ea4f762ddda5e7bd72cda3270192b22813d0330869d7b4", size = 16015555 },
+ { url = "https://files.pythonhosted.org/packages/7f/5c/1b5691575420135e90578543b2bf219497caa33cfd0af64cb38f30288450/litellm-1.83.14-py3-none-any.whl", hash = "sha256:92b11ba2a32cf80707ddf388d18526696c7999a21b418c5e3b6eda1243d2cfdb", size = 16457054 },
]
[[package]]
@@ -5124,7 +5124,7 @@ wheels = [
[[package]]
name = "openai"
-version = "2.30.0"
+version = "2.24.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -5136,9 +5136,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/88/15/52580c8fbc16d0675d516e8749806eda679b16de1e4434ea06fb6feaa610/openai-2.30.0.tar.gz", hash = "sha256:92f7661c990bda4b22a941806c83eabe4896c3094465030dd882a71abe80c885", size = 676084 }
+sdist = { url = "https://files.pythonhosted.org/packages/55/13/17e87641b89b74552ed408a92b231283786523edddc95f3545809fab673c/openai-2.24.0.tar.gz", hash = "sha256:1e5769f540dbd01cb33bc4716a23e67b9d695161a734aff9c5f925e2bf99a673", size = 658717 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/2a/9e/5bfa2270f902d5b92ab7d41ce0475b8630572e71e349b2a4996d14bdda93/openai-2.30.0-py3-none-any.whl", hash = "sha256:9a5ae616888eb2748ec5e0c5b955a51592e0b201a11f4262db920f2a78c5231d", size = 1146656 },
+ { url = "https://files.pythonhosted.org/packages/c9/30/844dc675ee6902579b8eef01ed23917cc9319a1c9c0c14ec6e39340c96d0/openai-2.24.0-py3-none-any.whl", hash = "sha256:fed30480d7d6c884303287bde864980a4b137b60553ffbcf9ab4a233b7a73d94", size = 1120122 },
]
[[package]]
@@ -6780,11 +6780,11 @@ wheels = [
[[package]]
name = "python-dotenv"
-version = "1.0.1"
+version = "1.2.2"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 }
+sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 },
+ { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101 },
]
[[package]]
@@ -7947,7 +7947,7 @@ wheels = [
[[package]]
name = "surf-new-backend"
-version = "0.0.19"
+version = "0.0.21"
source = { editable = "." }
dependencies = [
{ name = "alembic" },
@@ -8045,7 +8045,7 @@ requires-dist = [
{ name = "composio", specifier = ">=0.10.9" },
{ name = "datasets", specifier = ">=2.21.0" },
{ name = "daytona", specifier = ">=0.146.0" },
- { name = "deepagents", specifier = ">=0.4.12" },
+ { name = "deepagents", specifier = ">=0.4.12,<0.5" },
{ name = "discord-py", specifier = ">=2.5.2" },
{ name = "docling", specifier = ">=2.15.0" },
{ name = "elasticsearch", specifier = ">=9.1.1" },
@@ -8070,7 +8070,7 @@ requires-dist = [
{ name = "langgraph", specifier = ">=1.1.3" },
{ name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" },
{ name = "linkup-sdk", specifier = ">=0.2.4" },
- { name = "litellm", specifier = ">=1.83.4" },
+ { name = "litellm", specifier = ">=1.83.7" },
{ name = "llama-cloud-services", specifier = ">=0.6.25" },
{ name = "markdown", specifier = ">=3.7" },
{ name = "markdownify", specifier = ">=0.14.1" },
diff --git a/surfsense_browser_extension/package.json b/surfsense_browser_extension/package.json
index 146dd177e..f127b85c0 100644
--- a/surfsense_browser_extension/package.json
+++ b/surfsense_browser_extension/package.json
@@ -1,7 +1,7 @@
{
"name": "surfsense_browser_extension",
"displayName": "Surfsense Browser Extension",
- "version": "0.0.19",
+ "version": "0.0.21",
"description": "Extension to collect Browsing History for SurfSense.",
"author": "https://github.com/MODSetter",
"engines": {
diff --git a/surfsense_desktop/build/entitlements.mac.plist b/surfsense_desktop/build/entitlements.mac.plist
new file mode 100644
index 000000000..5647e7759
--- /dev/null
+++ b/surfsense_desktop/build/entitlements.mac.plist
@@ -0,0 +1,35 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+  <!-- Electron/V8 JIT under the hardened runtime -->
+  <key>com.apple.security.cs.allow-jit</key>
+  <true/>
+  <key>com.apple.security.cs.allow-unsigned-executable-memory</key>
+  <true/>
+
+  <!-- Load native modules/frameworks not signed with our Team ID -->
+  <key>com.apple.security.cs.allow-dyld-environment-variables</key>
+  <true/>
+  <key>com.apple.security.cs.disable-library-validation</key>
+  <true/>
+
+  <!-- Network client/server access -->
+  <key>com.apple.security.network.client</key>
+  <true/>
+  <key>com.apple.security.network.server</key>
+  <true/>
+
+  <!-- Camera -->
+  <key>com.apple.security.device.camera</key>
+  <true/>
+
+  <!-- Apple Events automation -->
+  <key>com.apple.security.automation.apple-events</key>
+  <true/>
+
+  <!-- Read/write user-selected files -->
+  <key>com.apple.security.files.user-selected.read-write</key>
+  <true/>
+</dict>
+</plist>
diff --git a/surfsense_desktop/electron-builder.yml b/surfsense_desktop/electron-builder.yml
index b0014a57b..e4e7670ec 100644
--- a/surfsense_desktop/electron-builder.yml
+++ b/surfsense_desktop/electron-builder.yml
@@ -46,8 +46,11 @@ mac:
icon: assets/icon.icns
category: public.app-category.productivity
artifactName: "${productName}-${version}-${arch}.${ext}"
- hardenedRuntime: false
+ hardenedRuntime: true
gatekeeperAssess: false
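+ # Notarization requires the hardened runtime; with it enabled, Electron
+ # needs the explicit entitlements in build/entitlements.mac.plist
+ # (JIT, unsigned executable memory, library-validation exceptions).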
+ entitlements: build/entitlements.mac.plist
+ entitlementsInherit: build/entitlements.mac.plist
+ notarize: true
extendInfo:
NSAccessibilityUsageDescription: "SurfSense uses accessibility features to bring the app to the foreground and interact with the active application when you use desktop assists."
NSScreenCaptureUsageDescription: "SurfSense uses screen capture so you can attach a selected region to chat (Screenshot Assist) or capture the full screen from the composer."
diff --git a/surfsense_desktop/package.json b/surfsense_desktop/package.json
index e2712d8ea..4826b904e 100644
--- a/surfsense_desktop/package.json
+++ b/surfsense_desktop/package.json
@@ -1,6 +1,6 @@
{
"name": "surfsense-desktop",
- "version": "0.0.19",
+ "version": "0.0.21",
"description": "SurfSense Desktop App",
"main": "dist/main.js",
"scripts": {
diff --git a/surfsense_web/app/(home)/free/page.tsx b/surfsense_web/app/(home)/free/page.tsx
index 8d9ed5cb1..3ddd5195f 100644
--- a/surfsense_web/app/(home)/free/page.tsx
+++ b/surfsense_web/app/(home)/free/page.tsx
@@ -127,7 +127,7 @@ const FAQ_ITEMS = [
{
question: "What happens after I use my free tokens?",
answer:
- "After your free tokens, create a free SurfSense account to unlock 3 million more premium tokens. Additional tokens can be purchased at $1 per million. Non-premium models remain unlimited for registered users.",
+ "After your free tokens, create a free SurfSense account to unlock $5 of premium credit. Additional credit can be topped up at $1 for $1 of credit, billed at the actual provider cost. Non-premium models remain unlimited for registered users.",
},
{
question: "Is Claude AI available without login?",
@@ -329,7 +329,7 @@ export default async function FreeHubPage() {
Want More Features?
- Create a free SurfSense account to unlock 3 million tokens, document uploads with
+ Create a free SurfSense account to unlock $5 of premium credit, document uploads with
citations, team collaboration, and integrations with Slack, Google Drive, Notion, and
30+ more tools.
diff --git a/surfsense_web/app/(home)/pricing/page.tsx b/surfsense_web/app/(home)/pricing/page.tsx
index 6ad9435bf..2a413b9a9 100644
--- a/surfsense_web/app/(home)/pricing/page.tsx
+++ b/surfsense_web/app/(home)/pricing/page.tsx
@@ -5,7 +5,7 @@ import { BreadcrumbNav } from "@/components/seo/breadcrumb-nav";
export const metadata: Metadata = {
title: "Pricing | SurfSense - Free AI Search Plans",
description:
- "Explore SurfSense plans and pricing. Start free with 500 pages & 3M premium tokens. Use ChatGPT, Claude AI, and premium AI models. Pay-as-you-go tokens at $1 per million.",
+ "Explore SurfSense plans and pricing. Start free with 500 pages & $5 in premium credits. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost.",
alternates: {
canonical: "https://surfsense.com/pricing",
},
diff --git a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx
index 3017160e1..0c5662712 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx
@@ -8,7 +8,7 @@ import { cn } from "@/lib/utils";
const TABS = [
{ id: "pages", label: "Pages" },
- { id: "tokens", label: "Premium Tokens" },
+ { id: "tokens", label: "Premium Credit" },
] as const;
type TabId = (typeof TABS)[number]["id"];
diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
index bb8f62703..9b5510df3 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
@@ -13,6 +13,7 @@ import { useParams } from "next/navigation";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { z } from "zod";
+import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms";
import {
clearTargetCommentIdAtom,
@@ -30,6 +31,7 @@ import {
clearPlanOwnerRegistry,
// extractWriteTodosFromContent,
} from "@/atoms/chat/plan-state.atom";
+import { setPremiumAlertForThreadAtom } from "@/atoms/chat/premium-alert.atom";
import { closeReportPanelAtom } from "@/atoms/chat/report-panel.atom";
import { type AgentCreatedDocument, agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms";
import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
@@ -59,27 +61,35 @@ import { useMessagesSync } from "@/hooks/use-messages-sync";
import { getAgentFilesystemSelection } from "@/lib/agent-filesystem";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { getBearerToken } from "@/lib/auth-utils";
+import { type ChatFlow, classifyChatError } from "@/lib/chat/chat-error-classifier";
+import { tagPreAcceptSendFailure, toHttpResponseError } from "@/lib/chat/chat-request-errors";
import { convertToThreadMessage } from "@/lib/chat/message-utils";
import {
isPodcastGenerating,
looksLikePodcastRequest,
setActivePodcastTaskId,
} from "@/lib/chat/podcast-state";
+import { createStreamFlushHelpers } from "@/lib/chat/stream-flush";
+import {
+ consumeSseEvents,
+ hasPersistableContent,
+ markInterruptsCompleted,
+ processSharedStreamEvent,
+} from "@/lib/chat/stream-pipeline";
+import {
+ applyTurnIdToAssistantMessageList,
+ mergeChatTurnIdIntoMessage,
+ readStreamedChatTurnId,
+ readStreamedMessageId,
+} from "@/lib/chat/stream-side-effects";
import {
- addStepSeparator,
addToolCall,
- appendReasoning,
- appendText,
- appendToolInputDelta,
buildContentForPersistence,
buildContentForUI,
type ContentPartsState,
- endReasoning,
- FrameBatchedUpdater,
- readSSEStream,
+ type FrameBatchedUpdater,
type ThinkingStepData,
type ToolUIGate,
- updateThinkingSteps,
updateToolCall,
} from "@/lib/chat/streaming-state";
import {
@@ -99,8 +109,9 @@ import {
import { NotFoundError } from "@/lib/error";
import { type BundleSubmit, HitlBundleProvider } from "@/lib/hitl";
import {
+ trackChatBlocked,
trackChatCreated,
- trackChatError,
+ trackChatErrorDetailed,
trackChatMessageSent,
trackChatResponseReceived,
} from "@/lib/posthog/events";
@@ -128,25 +139,6 @@ const MobileReportPanel = dynamic(
{ ssr: false }
);
-/**
- * After a tool produces output, mark any previously-decided interrupt tool
- * calls as completed so the ApprovalCard can transition from shimmer to done.
- */
-function markInterruptsCompleted(contentParts: Array<{ type: string; result?: unknown }>): void {
- for (const part of contentParts) {
- if (
- part.type === "tool-call" &&
- typeof part.result === "object" &&
- part.result !== null &&
- (part.result as Record<string, unknown>).__interrupt__ === true &&
- (part.result as Record<string, unknown>).__decided__ &&
- !(part.result as Record<string, unknown>).__completed__
- ) {
- part.result = { ...(part.result as Record<string, unknown>), __completed__: true };
- }
- }
-}
-
/**
* Generate a synthetic ``toolCallId`` for an action_request that has no
* matching streamed tool-call card (HITL-blocked subagent calls don't surface
@@ -243,28 +235,20 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] {
* ``stream_new_chat.py``) keep the JSON from ballooning.
*/
const TOOLS_WITH_UI_ALL: ToolUIGate = "all";
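+// Client-side retry policy for TURN_CANCELLING responses. Mirrors the
+// backend constants of the same names in new_chat_routes.py: exponential
+// backoff starting at 200ms, doubling each attempt, capped at 1.5s.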
+const TURN_CANCELLING_INITIAL_DELAY_MS = 200;
+const TURN_CANCELLING_BACKOFF_FACTOR = 2;
+const TURN_CANCELLING_MAX_DELAY_MS = 1500;
+const RECENT_CANCEL_WINDOW_MS = 5_000;
-/**
- * When a streamed message is persisted, the backend returns the durable
- * ``turn_id`` (``configurable.turn_id`` from the agent run). Merge it
- * into the assistant-ui message metadata so the per-turn "Revert turn"
- * button can scope to this turn's actions even after a full chat reload.
- */
-function mergeChatTurnIdIntoMessage(
- msg: ThreadMessageLike,
- turnId: string | null | undefined
-): ThreadMessageLike {
- if (!turnId) return msg;
- const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> };
- const existingCustom = existingMeta.custom ?? {};
- if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg;
- return {
- ...msg,
- metadata: {
- ...existingMeta,
- custom: { ...existingCustom, chatTurnId: turnId },
- },
- };
+function sleep(ms: number): Promise<void> {
+ return new Promise((resolve) => setTimeout(resolve, ms));
+}
+
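+// Fallback used when the server's TURN_CANCELLING response carries no
+// retry hint: attempts 1, 2, 3 wait 200ms, 400ms, 800ms, and attempt 4+
+// is capped at 1500ms.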
+function computeFallbackTurnCancellingRetryDelay(attempt: number): number {
+ const safeAttempt = Math.max(1, attempt);
+ const raw =
+ TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1);
+ return Math.min(raw, TURN_CANCELLING_MAX_DELAY_MS);
}
export default function NewChatPage() {
@@ -277,6 +261,7 @@ export default function NewChatPage() {
const [isRunning, setIsRunning] = useState(false);
const [tokenUsageStore] = useState(() => createTokenUsageStore());
 const abortControllerRef = useRef<AbortController | null>(null);
+ const recentCancelRequestedAtRef = useRef(0);
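+ // Epoch-ms timestamp of the most recent cancel request (compared
+ // against RECENT_CANCEL_WINDOW_MS).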
const [pendingInterrupt, setPendingInterrupt] = useState<{
threadId: number;
assistantMsgId: string;
@@ -284,6 +269,63 @@ export default function NewChatPage() {
bundleToolCallIds: string[];
} | null>(null);
const toolsWithUI = TOOLS_WITH_UI_ALL;
+ const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom);
+
+ const persistAssistantErrorMessage = useCallback(
+ async ({
+ threadId,
+ assistantMsgId,
+ text,
+ }: {
+ threadId: number | null;
+ assistantMsgId: string;
+ text: string;
+ }) => {
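+ // Replace the optimistic placeholder's content with the error text in
+ // the UI before attempting any persistence.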
+ setMessages((prev) =>
+ prev.map((m) =>
+ m.id === assistantMsgId
+ ? {
+ ...m,
+ content: [{ type: "text", text }],
+ }
+ : m
+ )
+ );
+
+ if (!threadId) return;
+
+ // Persist only temporary assistant placeholders to avoid duplicate rows
+ // when the message already has a database-backed ID.
+ if (!assistantMsgId.startsWith("msg-assistant-")) return;
+
+ try {
+ const savedMessage = await appendMessage(threadId, {
+ role: "assistant",
+ content: [{ type: "text", text }],
+ });
+ const newMsgId = `msg-${savedMessage.id}`;
+ tokenUsageStore.rename(assistantMsgId, newMsgId);
+ setMessages((prev) =>
+ prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
+ );
+ } catch (persistErr) {
+ console.error("Failed to persist assistant error message:", persistErr);
+ }
+ },
+ [tokenUsageStore]
+ );
+
+ // NOTE: ``persistUserTurn`` / ``persistAssistantTurn`` callbacks
+ // were removed in the SSE-based message ID handshake refactor.
+ // ``stream_new_chat`` and ``stream_resume_chat`` now persist both
+ // the user and assistant rows server-side via
+ // ``persist_user_turn`` / ``persist_assistant_shell`` and emit
+ // ``data-user-message-id`` / ``data-assistant-message-id`` SSE
+ // events; the consumers below rename the optimistic ids in real
+ // time. ``persistAssistantErrorMessage`` (above) is intentionally
+ // kept — it is the pre-stream-error fallback fired when the
+ // server NEVER accepted the request, and the BE has nothing to
+ // persist in that case.
// Get disabled tools from the tool toggle UI
const disabledTools = useAtomValue(disabledToolsAtom);
@@ -291,9 +333,10 @@ export default function NewChatPage() {
// Get mentioned document IDs from the composer.
const mentionedDocumentIds = useAtomValue(mentionedDocumentIdsAtom);
const mentionedDocuments = useAtomValue(mentionedDocumentsAtom);
+ const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
const setMentionedDocuments = useSetAtom(mentionedDocumentsAtom);
- const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom);
const setCurrentThreadState = useSetAtom(currentThreadAtom);
+ const setPremiumAlertForThread = useSetAtom(setPremiumAlertForThreadAtom);
const setTargetCommentId = useSetAtom(setTargetCommentIdAtom);
const clearTargetCommentId = useSetAtom(clearTargetCommentIdAtom);
const closeReportPanel = useSetAtom(closeReportPanelAtom);
@@ -317,6 +360,8 @@ export default function NewChatPage() {
// Get current user for author info in shared chats
const { data: currentUser } = useAtomValue(currentUserAtom);
+ const { data: agentFlags } = useAtomValue(agentFlagsAtom);
+ const localFilesystemEnabled = agentFlags?.enable_desktop_local_filesystem === true;
// Live collaboration: sync session state and messages via Zero
useChatSessionStateSync(threadId);
@@ -408,6 +453,143 @@ export default function NewChatPage() {
return Number.isNaN(parsed) ? 0 : parsed;
}, [params.chat_id]);
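+ /**
+ * Central failure sink for all three chat flows. Classifies the error via
+ * ``classifyChatError``, logs at the classifier's severity, emits the
+ * matching telemetry event, then routes the user-facing message to the
+ * chosen channel: silent, pinned inline alert (optionally with a
+ * persisted assistant message), or toast.
+ */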
+ const handleChatFailure = useCallback(
+ async ({
+ error,
+ flow,
+ threadId,
+ assistantMsgId,
+ }: {
+ error: unknown;
+ flow: ChatFlow;
+ threadId: number | null;
+ assistantMsgId: string;
+ }) => {
+ const normalized = classifyChatError({
+ error,
+ flow,
+ context: {
+ searchSpaceId,
+ threadId,
+ },
+ });
+
+ const logger =
+ normalized.severity === "error"
+ ? console.error
+ : normalized.severity === "warn"
+ ? console.warn
+ : console.info;
+ logger(`[NewChatPage] ${flow} ${normalized.kind}:`, error);
+
+ const telemetryPayload = {
+ flow,
+ kind: normalized.kind,
+ error_code: normalized.errorCode,
+ severity: normalized.severity,
+ is_expected: normalized.isExpected,
+ message: normalized.userMessage,
+ };
+ if (normalized.telemetryEvent === "chat_blocked") {
+ trackChatBlocked(searchSpaceId, threadId, telemetryPayload);
+ } else {
+ trackChatErrorDetailed(searchSpaceId, threadId, telemetryPayload);
+ }
+
+ if (normalized.channel === "silent") {
+ return;
+ }
+
+ if (normalized.channel === "pinned_inline") {
+ if (threadId) {
+ setPremiumAlertForThread({
+ threadId,
+ message: normalized.userMessage,
+ userId: currentUser?.id ?? null,
+ });
+ }
+ if (normalized.assistantMessage) {
+ await persistAssistantErrorMessage({
+ threadId,
+ assistantMsgId,
+ text: normalized.assistantMessage,
+ });
+ }
+ return;
+ }
+
+ toast.error(normalized.userMessage);
+ },
+ [currentUser?.id, persistAssistantErrorMessage, searchSpaceId, setPremiumAlertForThread]
+ );
+
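+ /**
+ * Shared terminal-error demux for the new/resume/regenerate streams.
+ * Aborts are delegated to ``onAbort`` and stay silent; otherwise the
+ * pre-accept or accepted-stream hook runs before ``handleChatFailure``.
+ * Pre-accept failures are tagged via ``tagPreAcceptSendFailure`` and use
+ * a sentinel assistant id so nothing is persisted for a request the
+ * server never accepted.
+ */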
+ const handleStreamTerminalError = useCallback(
+ async ({
+ error,
+ flow,
+ threadId,
+ assistantMsgId,
+ accepted,
+ onAbort,
+ onPreAcceptFailure,
+ onAcceptedStreamError,
+ }: {
+ error: unknown;
+ flow: ChatFlow;
+ threadId: number | null;
+ assistantMsgId: string;
+ accepted: boolean;
+ onAbort?: () => Promise<void>;
+ onPreAcceptFailure?: () => Promise<void>;
+ onAcceptedStreamError?: () => Promise<void>;
+ }) => {
+ if (error instanceof Error && error.name === "AbortError") {
+ await onAbort?.();
+ return;
+ }
+
+ if (!accepted) {
+ await onPreAcceptFailure?.();
+ } else {
+ await onAcceptedStreamError?.();
+ }
+
+ await handleChatFailure({
+ error: !accepted ? tagPreAcceptSendFailure(error) : error,
+ flow,
+ threadId,
+ assistantMsgId: accepted ? assistantMsgId : "no-persist-assistant",
+ });
+ },
+ [handleChatFailure]
+ );
+
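+ /**
+ * Retries the initial POST while a previous turn is still winding down:
+ * always on ``TURN_CANCELLING``, and on ``THREAD_BUSY`` only within
+ * ``RECENT_CANCEL_WINDOW_MS`` of a locally requested cancel. Honours the
+ * server's ``retryAfterMs`` hint when present, else falls back to the
+ * exponential schedule above.
+ */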
+ const fetchWithTurnCancellingRetry = useCallback(async (runFetch: () => Promise<Response>) => {
+ const maxAttempts = 4;
+ for (let attempt = 1; attempt <= maxAttempts; attempt += 1) {
+ const response = await runFetch();
+ if (response.ok) {
+ return response;
+ }
+ const error = await toHttpResponseError(response);
+ const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number };
+ const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING";
+ const isRecentThreadBusyAfterCancel =
+ withMeta.errorCode === "THREAD_BUSY" &&
+ Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS;
+ if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && attempt < maxAttempts) {
+ const waitMs = withMeta.retryAfterMs ?? computeFallbackTurnCancellingRetryDelay(attempt);
+ await sleep(waitMs);
+ continue;
+ }
+ throw error;
+ }
+
+ throw Object.assign(new Error("Turn cancellation retry limit exceeded"), {
+ errorCode: "TURN_CANCELLING",
+ });
+ }, []);
+
// Initialize thread and load messages
// For new chats (no urlChatId), we use lazy creation - thread is created on first message
const initializeThread = useCallback(async () => {
@@ -577,12 +759,39 @@ export default function NewChatPage() {
// Cancel ongoing request
const cancelRun = useCallback(async () => {
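+ // Two-phase cancel: (1) best-effort server signal so the BE can wind the
+ // active turn down; a ``TURN_CANCELLING`` ack arms the THREAD_BUSY grace
+ // window consumed by ``fetchWithTurnCancellingRetry``; then (2) abort
+ // the local SSE fetch.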
+ if (threadId) {
+ const token = getBearerToken();
+ if (token) {
+ const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
+ try {
+ const response = await fetch(
+ `${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`,
+ {
+ method: "POST",
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ }
+ );
+ if (response.ok) {
+ const payload = (await response.json()) as {
+ error_code?: string;
+ };
+ if (payload.error_code === "TURN_CANCELLING") {
+ recentCancelRequestedAtRef.current = Date.now();
+ }
+ }
+ } catch (error) {
+ console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error);
+ }
+ }
+ }
if (abortControllerRef.current) {
abortControllerRef.current.abort();
abortControllerRef.current = null;
}
setIsRunning(false);
- }, []);
+ }, [threadId]);
// Handle new message from user
const onNew = useCallback(
@@ -634,7 +843,12 @@ export default function NewChatPage() {
);
} catch (error) {
console.error("[NewChatPage] Failed to create thread:", error);
- toast.error("Failed to start chat. Please try again.");
+ await handleChatFailure({
+ error: tagPreAcceptSendFailure(error),
+ flow: "new",
+ threadId: currentThreadId,
+ assistantMsgId: "no-persist-assistant",
+ });
return;
}
}
@@ -643,8 +857,13 @@ export default function NewChatPage() {
setPendingUserImageUrls((prev) => prev.filter((u) => !urlsSnapshot.includes(u)));
}
- // Add user message to state
- const userMsgId = `msg-user-${Date.now()}`;
+ // Add user message to state. Mutable because the SSE
+ // ``data-user-message-id`` handler (below) renames this
+ // optimistic id to the canonical ``msg-{db_id}`` once the
+ // backend's ``persist_user_turn`` resolves the row, and
+ // the in-stream flush / interrupt closures need to see
+ // the post-rename value via this live ``let`` binding.
+ let userMsgId = `msg-user-${Date.now()}`;
// Always include author metadata so the UI layer can decide visibility
const authorMetadata = currentUser
@@ -710,72 +929,33 @@ export default function NewChatPage() {
}));
}
- const persistContent: unknown[] = [...userDisplayContent];
-
- if (allMentionedDocs.length > 0) {
- persistContent.push({
- type: "mentioned-documents",
- documents: allMentionedDocs,
- });
- }
-
- appendMessage(currentThreadId, {
- role: "user",
- content: persistContent,
- })
- .then((savedMessage) => {
- const newUserMsgId = `msg-${savedMessage.id}`;
- setMessages((prev) =>
- prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m))
- );
- setMessageDocumentsMap((prev) => {
- const docs = prev[userMsgId];
- if (!docs) return prev;
- const { [userMsgId]: _, ...rest } = prev;
- return { ...rest, [newUserMsgId]: docs };
- });
- if (isNewThread) {
- queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] });
- }
- })
- .catch((err) => console.error("Failed to persist user message:", err));
-
// Start streaming response
setIsRunning(true);
const controller = new AbortController();
abortControllerRef.current = controller;
- // Prepare assistant message
- const assistantMsgId = `msg-assistant-${Date.now()}`;
+ // Prepare assistant message. Mutable for the same reason
+ // as ``userMsgId`` above — the ``data-assistant-message-id``
+ // SSE handler reassigns this once
+ // ``persist_assistant_shell`` returns its canonical id.
+ let assistantMsgId = `msg-assistant-${Date.now()}`;
const currentThinkingSteps = new Map<string, ThinkingStepData>();
- const batcher = new FrameBatchedUpdater();
-
const contentPartsState: ContentPartsState = {
contentParts: [],
currentTextPartIndex: -1,
currentReasoningPartIndex: -1,
toolCallIndices: new Map(),
};
- const { contentParts, toolCallIndices } = contentPartsState;
+ const { contentParts } = contentPartsState;
let wasInterrupted = false;
- let tokenUsageData: Record<string, unknown> | null = null;
- // Captured from ``data-turn-info`` at stream start.
- let streamedChatTurnId: string | null = null;
-
- // Add placeholder assistant message
- setMessages((prev) => [
- ...prev,
- {
- id: assistantMsgId,
- role: "assistant",
- content: [{ type: "text", text: "" }],
- createdAt: new Date(),
- },
- ]);
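+ // Flipped once the BE accepts the request (HTTP 2xx). Gates whether a
+ // later failure is handled as pre-accept (roll back optimistic UI, no
+ // server-side persistence ran) or as an accepted-stream error (the BE
+ // is already persisting server-side).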
+ let newAccepted = false;
+ let streamBatcher: FrameBatchedUpdater | null = null;
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
- const selection = await getAgentFilesystemSelection(searchSpaceId);
+ const selection = await getAgentFilesystemSelection(searchSpaceId, {
+ localFilesystemEnabled,
+ });
if (
selection.filesystem_mode === "desktop_local_folder" &&
(!selection.local_filesystem_mounts || selection.local_filesystem_mounts.length === 0)
@@ -807,33 +987,59 @@ export default function NewChatPage() {
setMentionedDocuments([]);
}
- const response = await fetch(`${backendUrl}/api/v1/new_chat`, {
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- Authorization: `Bearer ${token}`,
- },
- body: JSON.stringify({
- chat_id: currentThreadId,
- user_query: userQuery.trim(),
- search_space_id: searchSpaceId,
- filesystem_mode: selection.filesystem_mode,
- client_platform: selection.client_platform,
- local_filesystem_mounts: selection.local_filesystem_mounts,
- messages: messageHistory,
- mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined,
- mentioned_surfsense_doc_ids: hasSurfsenseDocIds
- ? mentionedDocumentIds.surfsense_doc_ids
- : undefined,
- disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
- ...(userImages.length > 0 ? { user_images: userImages } : {}),
- }),
- signal: controller.signal,
- });
+ const response = await fetchWithTurnCancellingRetry(() =>
+ fetch(`${backendUrl}/api/v1/new_chat`, {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ Authorization: `Bearer ${token}`,
+ },
+ body: JSON.stringify({
+ chat_id: currentThreadId,
+ user_query: userQuery.trim(),
+ search_space_id: searchSpaceId,
+ filesystem_mode: selection.filesystem_mode,
+ client_platform: selection.client_platform,
+ local_filesystem_mounts: selection.local_filesystem_mounts,
+ messages: messageHistory,
+ mentioned_document_ids: hasDocumentIds
+ ? mentionedDocumentIds.document_ids
+ : undefined,
+ mentioned_surfsense_doc_ids: hasSurfsenseDocIds
+ ? mentionedDocumentIds.surfsense_doc_ids
+ : undefined,
+ // Full mention metadata so the BE can embed a
+ // ``mentioned-documents`` ContentPart on the
+ // persisted user message (replaces the old FE-side
+ // injection in ``persistUserTurn``).
+ mentioned_documents:
+ allMentionedDocs.length > 0
+ ? allMentionedDocs.map((d) => ({
+ id: d.id,
+ title: d.title,
+ document_type: d.document_type,
+ }))
+ : undefined,
+ disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
+ ...(userImages.length > 0 ? { user_images: userImages } : {}),
+ }),
+ signal: controller.signal,
+ })
+ );
if (!response.ok) {
- throw new Error(`Backend error: ${response.status}`);
+ throw await toHttpResponseError(response);
}
+ newAccepted = true;
+ setMessages((prev) => [
+ ...prev,
+ {
+ id: assistantMsgId,
+ role: "assistant",
+ content: [{ type: "text", text: "" }],
+ createdAt: new Date(),
+ },
+ ]);
const flushMessages = () => {
setMessages((prev) =>
@@ -844,123 +1050,41 @@ export default function NewChatPage() {
)
);
};
- const scheduleFlush = () => batcher.schedule(flushMessages);
- // Force-flush helper: ``batcher.flush()`` is a no-op when
- // ``dirty=false`` (e.g. a tool starts before any text
- // streamed). ``scheduleFlush(); batcher.flush()`` sets
- // the dirty bit FIRST so terminal events render
- // promptly without the 50ms throttle delay.
- const forceFlush = () => {
- scheduleFlush();
- batcher.flush();
- };
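+ // ``scheduleFlush`` coalesces high-frequency deltas onto a frame
+ // boundary; ``forceFlush`` sets the dirty bit first and then flushes so
+ // terminal events render without the throttle delay.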
+ const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages);
+ streamBatcher = batcher;
- for await (const parsed of readSSEStream(response)) {
- switch (parsed.type) {
- case "text-delta":
- appendText(contentPartsState, parsed.delta);
- scheduleFlush();
- break;
-
- case "reasoning-delta":
- appendReasoning(contentPartsState, parsed.delta);
- scheduleFlush();
- break;
-
- case "reasoning-end":
- endReasoning(contentPartsState);
- scheduleFlush();
- break;
-
- case "start-step":
- addStepSeparator(contentPartsState);
- scheduleFlush();
- break;
-
- case "finish-step":
- break;
-
- case "tool-input-start":
- addToolCall(
- contentPartsState,
- toolsWithUI,
- parsed.toolCallId,
- parsed.toolName,
- {},
- false,
- parsed.langchainToolCallId
- );
- forceFlush();
- break;
-
- case "tool-input-delta":
- // High-frequency event: deltas can fire dozens
- // of times per call, so use throttled
- // scheduleFlush (NOT forceFlush) to coalesce.
- appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
- scheduleFlush();
- break;
-
- case "tool-input-available": {
- const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
- if (toolCallIndices.has(parsed.toolCallId)) {
- updateToolCall(contentPartsState, parsed.toolCallId, {
- args: parsed.input || {},
- argsText: finalArgsText,
- langchainToolCallId: parsed.langchainToolCallId,
- });
- } else {
- addToolCall(
- contentPartsState,
- toolsWithUI,
- parsed.toolCallId,
- parsed.toolName,
- parsed.input || {},
- false,
- parsed.langchainToolCallId
- );
- // addToolCall doesn't accept argsText today;
- // backfill via updateToolCall so the new card
- // renders pretty-printed JSON.
- updateToolCall(contentPartsState, parsed.toolCallId, {
- argsText: finalArgsText,
- });
- }
- forceFlush();
- break;
- }
-
- case "tool-output-available": {
- updateToolCall(contentPartsState, parsed.toolCallId, {
- result: parsed.output,
- langchainToolCallId: parsed.langchainToolCallId,
- });
- markInterruptsCompleted(contentParts);
- if (parsed.output?.status === "pending" && parsed.output?.podcast_id) {
- const idx = toolCallIndices.get(parsed.toolCallId);
- if (idx !== undefined) {
- const part = contentParts[idx];
- if (part?.type === "tool-call" && part.toolName === "generate_podcast") {
- setActivePodcastTaskId(String(parsed.output.podcast_id));
+ await consumeSseEvents(response, async (parsed) => {
+ if (
+ processSharedStreamEvent(parsed, {
+ contentPartsState,
+ toolsWithUI,
+ currentThinkingSteps,
+ scheduleFlush,
+ forceFlush,
+ onTokenUsage: (data) => {
+ tokenUsageStore.set(assistantMsgId, data);
+ },
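+ // A "cancelling" turn status means the BE is winding this turn
+ // down; arm the THREAD_BUSY grace window so an immediate resubmit
+ // retries instead of failing.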
+ onTurnStatus: (data) => {
+ if (data.status === "cancelling") {
+ recentCancelRequestedAtRef.current = Date.now();
+ }
+ },
+ onToolOutputAvailable: (event, sharedCtx) => {
+ if (event.output?.status === "pending" && event.output?.podcast_id) {
+ const idx = sharedCtx.toolCallIndices.get(event.toolCallId);
+ if (idx !== undefined) {
+ const part = sharedCtx.contentPartsState.contentParts[idx];
+ if (part?.type === "tool-call" && part.toolName === "generate_podcast") {
+ setActivePodcastTaskId(String(event.output.podcast_id));
+ }
}
}
- }
- forceFlush();
- break;
- }
-
- case "data-thinking-step": {
- const stepData = parsed.data as ThinkingStepData;
- if (stepData?.id) {
- currentThinkingSteps.set(stepData.id, stepData);
- const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps);
- if (didUpdate) {
- scheduleFlush();
- }
- }
- break;
- }
-
+ },
+ })
+ ) {
+ return;
+ }
+ switch (parsed.type) {
case "data-thread-title-update": {
const titleData = parsed.data as { threadId: number; title: string };
if (titleData?.title && titleData?.threadId === currentThreadId) {
@@ -1006,13 +1130,21 @@ export default function NewChatPage() {
name: string;
args: Record<string, unknown>;
}>;
- const paired = pairBundleToolCallIds(toolCallIndices, contentParts, actionRequests);
+ const paired = pairBundleToolCallIds(
+ contentPartsState.toolCallIndices,
+ contentPartsState.contentParts,
+ actionRequests
+ );
const bundleToolCallIds: string[] = [];
for (let i = 0; i < actionRequests.length; i++) {
const action = actionRequests[i];
let targetTcId = paired[i];
if (!targetTcId) {
- targetTcId = freshSynthToolCallId(toolCallIndices, action.name, i);
+ targetTcId = freshSynthToolCallId(
+ contentPartsState.toolCallIndices,
+ action.name,
+ i
+ );
addToolCall(
contentPartsState,
toolsWithUI,
@@ -1061,125 +1193,131 @@ export default function NewChatPage() {
}
case "data-turn-info": {
- streamedChatTurnId = parsed.data.chat_turn_id || null;
- if (streamedChatTurnId) {
+ const turnId = readStreamedChatTurnId(parsed.data);
+ if (turnId) {
setMessages((prev) =>
- prev.map((m) =>
- m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m
- )
+ applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
);
}
break;
}
- case "data-token-usage":
- tokenUsageData = parsed.data;
- tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
- break;
-
- case "error":
- throw new Error(parsed.errorText || "Server error");
- }
- }
-
- batcher.flush();
-
- // Skip persistence for interrupted messages -- handleResume will persist the final version
- const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
- if (contentParts.length > 0 && !wasInterrupted) {
- try {
- const savedMessage = await appendMessage(currentThreadId, {
- role: "assistant",
- content: finalContent,
- token_usage: tokenUsageData ?? undefined,
- turn_id: streamedChatTurnId,
- });
-
- // Update message ID from temporary to database ID so comments work immediately
- const newMsgId = `msg-${savedMessage.id}`;
- tokenUsageStore.rename(assistantMsgId, newMsgId);
- setMessages((prev) =>
- prev.map((m) =>
- m.id === assistantMsgId
- ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
- : m
- )
- );
-
- // Update pending interrupt with the new persisted message ID
- setPendingInterrupt((prev) =>
- prev && prev.assistantMsgId === assistantMsgId
- ? { ...prev, assistantMsgId: newMsgId }
- : prev
- );
- } catch (err) {
- console.error("Failed to persist assistant message:", err);
- }
-
- // Track successful response
- trackChatResponseReceived(searchSpaceId, currentThreadId);
- }
- } catch (error) {
- batcher.dispose();
- if (error instanceof Error && error.name === "AbortError") {
- // Request was cancelled by user - persist partial response if any content was received
- const hasContent = contentParts.some(
- (part) =>
- (part.type === "text" && part.text.length > 0) ||
- (part.type === "reasoning" && part.text.length > 0) ||
- (part.type === "tool-call" &&
- (toolsWithUI === "all" || toolsWithUI.has(part.toolName)))
- );
- if (hasContent && currentThreadId) {
- const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
- try {
- const savedMessage = await appendMessage(currentThreadId, {
- role: "assistant",
- content: partialContent,
- turn_id: streamedChatTurnId,
- });
-
- // Update message ID from temporary to database ID
- const newMsgId = `msg-${savedMessage.id}`;
+ case "data-user-message-id": {
+ // Server-authoritative user message id resolved by
+ // ``persist_user_turn`` (or recovered via ON CONFLICT).
+ // Rename the optimistic ``msg-user-XXX`` placeholder to
+ // the canonical ``msg-{db_id}`` so DB-id-gated UI
+ // (comments, edit-from-this-message) unlocks immediately,
+ // migrate the local mentioned-documents map, and reassign
+ // the closure variable so all downstream
+ // ``m.id === userMsgId`` checks see the new value.
+ const parsedMsg = readStreamedMessageId(parsed.data);
+ if (!parsedMsg) break;
+ const newUserMsgId = `msg-${parsedMsg.messageId}`;
+ const oldUserMsgId = userMsgId;
setMessages((prev) =>
prev.map((m) =>
- m.id === assistantMsgId
- ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
+ m.id === oldUserMsgId
+ ? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, parsedMsg.turnId)
: m
)
);
- } catch (err) {
- console.error("Failed to persist partial assistant message:", err);
+ if (allMentionedDocs.length > 0) {
+ setMessageDocumentsMap((prev) => {
+ if (!(oldUserMsgId in prev)) {
+ return { ...prev, [newUserMsgId]: allMentionedDocs };
+ }
+ const { [oldUserMsgId]: _removed, ...rest } = prev;
+ return { ...rest, [newUserMsgId]: allMentionedDocs };
+ });
+ }
+ userMsgId = newUserMsgId;
+ if (isNewThread) {
+ // First user-side row landed in ``new_chat_messages``;
+ // refresh the sidebar so the freshly-bumped
+ // ``thread.updated_at`` reorders this thread.
+ queryClient.invalidateQueries({
+ queryKey: ["threads", String(searchSpaceId)],
+ });
+ }
+ break;
+ }
+
+ case "data-assistant-message-id": {
+ // Server-authoritative assistant message id resolved
+ // by ``persist_assistant_shell``. Rename the optimistic
+ // id, migrate ``tokenUsageStore`` so any pending
+ // ``data-token-usage`` payload binds to the new id,
+ // remap any in-flight ``pendingInterrupt`` reference,
+ // and reassign the closure variable so the in-stream
+ // flush callback (line ~1074) keeps writing to the
+ // renamed message.
+ const parsedMsg = readStreamedMessageId(parsed.data);
+ if (!parsedMsg) break;
+ const newAssistantMsgId = `msg-${parsedMsg.messageId}`;
+ const oldAssistantMsgId = assistantMsgId;
+ tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId);
+ setMessages((prev) =>
+ prev.map((m) =>
+ m.id === oldAssistantMsgId
+ ? mergeChatTurnIdIntoMessage({ ...m, id: newAssistantMsgId }, parsedMsg.turnId)
+ : m
+ )
+ );
+ setPendingInterrupt((prev) =>
+ prev && prev.assistantMsgId === oldAssistantMsgId
+ ? { ...prev, assistantMsgId: newAssistantMsgId }
+ : prev
+ );
+ assistantMsgId = newAssistantMsgId;
+ break;
}
}
- return;
+ });
+
+ batcher.flush();
+
+ // Server-authoritative persistence: ``stream_new_chat``
+ // already wrote the user row in ``persist_user_turn``
+ // (the FE renamed the optimistic id mid-stream via
+ // ``data-user-message-id``) and finalises the assistant
+ // row in ``finalize_assistant_turn`` from a shielded
+ // ``finally`` block. Nothing left for the FE to persist
+ // here — track the response and unblock the UI.
+ if (contentParts.length > 0 && !wasInterrupted) {
+ trackChatResponseReceived(searchSpaceId, currentThreadId);
}
- console.error("[NewChatPage] Chat error:", error);
-
- // Track chat error
- trackChatError(
- searchSpaceId,
- currentThreadId,
- error instanceof Error ? error.message : "Unknown error"
- );
-
- toast.error("Failed to get response. Please try again.");
- // Update assistant message with error
- setMessages((prev) =>
- prev.map((m) =>
- m.id === assistantMsgId
- ? {
- ...m,
- content: [
- {
- type: "text",
- text: "Sorry, there was an error. Please try again.",
- },
- ],
- }
- : m
- )
- );
+ } catch (error) {
+ streamBatcher?.dispose();
+ await handleStreamTerminalError({
+ error,
+ flow: "new",
+ threadId: currentThreadId,
+ assistantMsgId,
+ accepted: newAccepted,
+ // Server-side ``finalize_assistant_turn`` runs from a
+ // shielded ``anyio.CancelScope(shield=True)`` finally
+ // block, so partial content (incl. abort-mid-stream)
+ // is already persisted by the BE for the assistant
+ // row, and ``persist_user_turn`` ran before any LLM
+ // call. The FE's only remaining responsibility on
+ // abort / accepted-stream-error is to surface the
+ // error toast (handled by ``handleStreamTerminalError``
+ // itself).
+ onPreAcceptFailure: async () => {
+ // Pre-accept failure means the BE never accepted the
+ // request — no server-side persistence ran. Roll
+ // back the optimistic UI insertions we made before
+ // the fetch so the user message and any local
+ // mentioned-docs metadata don't linger.
+ setMessages((prev) => prev.filter((m) => m.id !== userMsgId));
+ setMessageDocumentsMap((prev) => {
+ if (!(userMsgId in prev)) return prev;
+ const { [userMsgId]: _removed, ...rest } = prev;
+ return rest;
+ });
+ },
+ });
} finally {
setIsRunning(false);
abortControllerRef.current = null;
@@ -1196,12 +1334,15 @@ export default function NewChatPage() {
setAgentCreatedDocuments,
queryClient,
currentUser,
+ localFilesystemEnabled,
disabledTools,
updateChatTabTitle,
tokenUsageStore,
pendingUserImageUrls,
setPendingUserImageUrls,
- toolsWithUI,
+ fetchWithTurnCancellingRetry,
+ handleStreamTerminalError,
+ handleChatFailure,
]
);
@@ -1214,7 +1355,12 @@ export default function NewChatPage() {
}>
) => {
if (!pendingInterrupt) return;
- const { threadId: resumeThreadId, assistantMsgId } = pendingInterrupt;
+ const { threadId: resumeThreadId } = pendingInterrupt;
+ // Destructured separately as ``let`` so the SSE
+ // ``data-assistant-message-id`` handler (resume always
+ // allocates a fresh server-side row) can rename it to
+ // the canonical ``msg-{db_id}`` mid-stream.
+ let assistantMsgId = pendingInterrupt.assistantMsgId;
setPendingInterrupt(null);
setIsRunning(true);
@@ -1229,7 +1375,6 @@ export default function NewChatPage() {
abortControllerRef.current = controller;
const currentThinkingSteps = new Map<string, ThinkingStepData>();
- const batcher = new FrameBatchedUpdater();
const contentPartsState: ContentPartsState = {
contentParts: [],
@@ -1238,9 +1383,8 @@ export default function NewChatPage() {
toolCallIndices: new Map(),
};
const { contentParts, toolCallIndices } = contentPartsState;
- let tokenUsageData: Record<string, unknown> | null = null;
- // Captured from ``data-turn-info`` at stream start.
- let streamedChatTurnId: string | null = null;
+ let resumeAccepted = false;
+ let streamBatcher: FrameBatchedUpdater | null = null;
const existingMsg = messages.find((m) => m.id === assistantMsgId);
if (existingMsg && Array.isArray(existingMsg.content)) {
@@ -1318,27 +1462,32 @@ export default function NewChatPage() {
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
- const selection = await getAgentFilesystemSelection(searchSpaceId);
- const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, {
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- Authorization: `Bearer ${token}`,
- },
- body: JSON.stringify({
- search_space_id: searchSpaceId,
- decisions,
- disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
- filesystem_mode: selection.filesystem_mode,
- client_platform: selection.client_platform,
- local_filesystem_mounts: selection.local_filesystem_mounts,
- }),
- signal: controller.signal,
+ const selection = await getAgentFilesystemSelection(searchSpaceId, {
+ localFilesystemEnabled,
});
+ const response = await fetchWithTurnCancellingRetry(() =>
+ fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ Authorization: `Bearer ${token}`,
+ },
+ body: JSON.stringify({
+ search_space_id: searchSpaceId,
+ decisions,
+ disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
+ filesystem_mode: selection.filesystem_mode,
+ client_platform: selection.client_platform,
+ local_filesystem_mounts: selection.local_filesystem_mounts,
+ }),
+ signal: controller.signal,
+ })
+ );
if (!response.ok) {
- throw new Error(`Backend error: ${response.status}`);
+ throw await toHttpResponseError(response);
}
+ resumeAccepted = true;
const flushMessages = () => {
setMessages((prev) =>
@@ -1349,115 +1498,51 @@ export default function NewChatPage() {
)
);
};
- const scheduleFlush = () => batcher.schedule(flushMessages);
- const forceFlush = () => {
- scheduleFlush();
- batcher.flush();
- };
+ const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages);
+ streamBatcher = batcher;
- for await (const parsed of readSSEStream(response)) {
- switch (parsed.type) {
- case "text-delta":
- appendText(contentPartsState, parsed.delta);
- scheduleFlush();
- break;
-
- case "reasoning-delta":
- appendReasoning(contentPartsState, parsed.delta);
- scheduleFlush();
- break;
-
- case "reasoning-end":
- endReasoning(contentPartsState);
- scheduleFlush();
- break;
-
- case "start-step":
- addStepSeparator(contentPartsState);
- scheduleFlush();
- break;
-
- case "finish-step":
- break;
-
- case "tool-input-start":
- addToolCall(
- contentPartsState,
- toolsWithUI,
- parsed.toolCallId,
- parsed.toolName,
- {},
- false,
- parsed.langchainToolCallId
- );
- forceFlush();
- break;
-
- case "tool-input-delta":
- appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
- scheduleFlush();
- break;
-
- case "tool-input-available": {
- const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
- if (toolCallIndices.has(parsed.toolCallId)) {
- updateToolCall(contentPartsState, parsed.toolCallId, {
- args: parsed.input || {},
- argsText: finalArgsText,
- langchainToolCallId: parsed.langchainToolCallId,
- });
- } else {
- addToolCall(
- contentPartsState,
- toolsWithUI,
- parsed.toolCallId,
- parsed.toolName,
- parsed.input || {},
- false,
- parsed.langchainToolCallId
- );
- updateToolCall(contentPartsState, parsed.toolCallId, {
- argsText: finalArgsText,
- });
- }
- forceFlush();
- break;
- }
-
- case "tool-output-available":
- updateToolCall(contentPartsState, parsed.toolCallId, {
- result: parsed.output,
- langchainToolCallId: parsed.langchainToolCallId,
- });
- markInterruptsCompleted(contentParts);
- forceFlush();
- break;
-
- case "data-thinking-step": {
- const stepData = parsed.data as ThinkingStepData;
- if (stepData?.id) {
- currentThinkingSteps.set(stepData.id, stepData);
- const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps);
- if (didUpdate) {
- scheduleFlush();
+ await consumeSseEvents(response, async (parsed) => {
+ if (
+ processSharedStreamEvent(parsed, {
+ contentPartsState,
+ toolsWithUI,
+ currentThinkingSteps,
+ scheduleFlush,
+ forceFlush,
+ onTokenUsage: (data) => {
+ tokenUsageStore.set(assistantMsgId, data);
+ },
+ onTurnStatus: (data) => {
+ if (data.status === "cancelling") {
+ recentCancelRequestedAtRef.current = Date.now();
}
- }
- break;
- }
-
+ },
+ })
+ ) {
+ return;
+ }
+ switch (parsed.type) {
case "data-interrupt-request": {
const interruptData = parsed.data as Record<string, unknown>;
const actionRequests = (interruptData.action_requests ?? []) as Array<{
name: string;
args: Record<string, unknown>;
}>;
- const paired = pairBundleToolCallIds(toolCallIndices, contentParts, actionRequests);
+ const paired = pairBundleToolCallIds(
+ contentPartsState.toolCallIndices,
+ contentPartsState.contentParts,
+ actionRequests
+ );
const bundleToolCallIds: string[] = [];
for (let i = 0; i < actionRequests.length; i++) {
const action = actionRequests[i];
let targetTcId = paired[i];
if (!targetTcId) {
- targetTcId = freshSynthToolCallId(toolCallIndices, action.name, i);
+ targetTcId = freshSynthToolCallId(
+ contentPartsState.toolCallIndices,
+ action.name,
+ i
+ );
addToolCall(
contentPartsState,
toolsWithUI,
@@ -1504,64 +1589,74 @@ export default function NewChatPage() {
}
case "data-turn-info": {
- streamedChatTurnId = parsed.data.chat_turn_id || null;
- if (streamedChatTurnId) {
+ const turnId = readStreamedChatTurnId(parsed.data);
+ if (turnId) {
setMessages((prev) =>
- prev.map((m) =>
- m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m
- )
+ applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
);
}
break;
}
- case "data-token-usage":
- tokenUsageData = parsed.data;
- tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
+ case "data-assistant-message-id": {
+ // Resume always allocates a fresh ``new_chat_messages``
+ // row anchored to a new ``turn_id`` (the original
+ // interrupted turn's row stays as-is), so this is a
+ // real id swap. Rename the optimistic placeholder to
+ // ``msg-{db_id}`` and reassign closure state. Resume
+ // does NOT emit ``data-user-message-id`` — the user
+ // row belongs to the original interrupted turn.
+ const parsedMsg = readStreamedMessageId(parsed.data);
+ if (!parsedMsg) break;
+ const newAssistantMsgId = `msg-${parsedMsg.messageId}`;
+ const oldAssistantMsgId = assistantMsgId;
+ tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId);
+ setMessages((prev) =>
+ prev.map((m) =>
+ m.id === oldAssistantMsgId
+ ? mergeChatTurnIdIntoMessage({ ...m, id: newAssistantMsgId }, parsedMsg.turnId)
+ : m
+ )
+ );
+ assistantMsgId = newAssistantMsgId;
break;
-
- case "error":
- throw new Error(parsed.errorText || "Server error");
+ }
}
- }
+ });
batcher.flush();
- const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
- if (contentParts.length > 0) {
- try {
- const savedMessage = await appendMessage(resumeThreadId, {
- role: "assistant",
- content: finalContent,
- token_usage: tokenUsageData ?? undefined,
- turn_id: streamedChatTurnId,
- });
- const newMsgId = `msg-${savedMessage.id}`;
- tokenUsageStore.rename(assistantMsgId, newMsgId);
- setMessages((prev) =>
- prev.map((m) =>
- m.id === assistantMsgId
- ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
- : m
- )
- );
- } catch (err) {
- console.error("Failed to persist resumed assistant message:", err);
- }
- }
+ // Server-authoritative persistence: ``stream_resume_chat``
+ // finalises the assistant row in
+ // ``finalize_assistant_turn`` from a shielded
+ // ``finally`` block (covers both happy-path and
+ // abort-mid-stream). FE has no remaining persistence
+ // work here.
} catch (error) {
- batcher.dispose();
- if (error instanceof Error && error.name === "AbortError") {
- return;
- }
- console.error("[NewChatPage] Resume error:", error);
- toast.error("Failed to resume. Please try again.");
+ streamBatcher?.dispose();
+ await handleStreamTerminalError({
+ error,
+ flow: "resume",
+ threadId: resumeThreadId,
+ assistantMsgId,
+ accepted: resumeAccepted,
+ });
} finally {
setIsRunning(false);
abortControllerRef.current = null;
}
},
- [pendingInterrupt, messages, searchSpaceId, tokenUsageStore, toolsWithUI]
+ [
+ pendingInterrupt,
+ messages,
+ searchSpaceId,
+ localFilesystemEnabled,
+ disabledTools,
+ queryClient,
+ tokenUsageStore,
+ fetchWithTurnCancellingRetry,
+ handleStreamTerminalError,
+ ]
);
useEffect(() => {
@@ -1705,6 +1800,7 @@ export default function NewChatPage() {
editExtras?: {
userMessageContent: ThreadMessageLike["content"];
userImages: NewChatUserImagePayload[];
+ sourceUserMessageId?: string;
},
editFromPosition?: {
/** Message id (numeric, parsed from ``msg-``) to rewind to. */
@@ -1736,11 +1832,13 @@ export default function NewChatPage() {
let userQueryToDisplay: string | undefined;
let originalUserMessageContent: ThreadMessageLike["content"] | null = null;
let originalUserMessageMetadata: ThreadMessageLike["metadata"] | undefined;
+ let sourceUserMessageId: string | undefined = editExtras?.sourceUserMessageId;
if (!isEdit) {
// Reload mode - find and preserve the last user message content
const lastUserMessage = [...messages].reverse().find((m) => m.role === "user");
if (lastUserMessage) {
+ sourceUserMessageId = lastUserMessage.id;
originalUserMessageContent = lastUserMessage.content;
originalUserMessageMetadata = lastUserMessage.metadata;
// Extract text for the API request
@@ -1755,34 +1853,17 @@ export default function NewChatPage() {
userQueryToDisplay = newUserQuery;
}
- // Remove downstream messages from the UI immediately. The
- // backend will also delete them from the database.
- //
- // When an explicit ``fromMessageId`` is passed, slice from
- // that message forward; otherwise fall back to the legacy
- // "drop the last 2" behaviour.
- setMessages((prev) => {
- if (editFromPosition?.fromMessageId != null) {
- const targetId = `msg-${editFromPosition.fromMessageId}`;
- const sliceIndex = prev.findIndex((m) => m.id === targetId);
- if (sliceIndex >= 0) {
- return prev.slice(0, sliceIndex);
- }
- }
- if (prev.length >= 2) {
- return prev.slice(0, -2);
- }
- return prev;
- });
-
// Start streaming
setIsRunning(true);
const controller = new AbortController();
abortControllerRef.current = controller;
- // Add placeholder user message if we have a new query (edit mode)
- const userMsgId = `msg-user-${Date.now()}`;
- const assistantMsgId = `msg-assistant-${Date.now()}`;
+ // Add placeholder user message if we have a new query (edit mode).
+ // Mutable for the same reason as in ``onNew`` — both ids are
+ // renamed mid-stream by the new ``data-user-message-id`` /
+ // ``data-assistant-message-id`` SSE handlers below.
+ let userMsgId = `msg-user-${Date.now()}`;
+ let assistantMsgId = `msg-assistant-${Date.now()}`;
const currentThinkingSteps = new Map<string, ThinkingStepData>();
const contentPartsState: ContentPartsState = {
@@ -1791,13 +1872,9 @@ export default function NewChatPage() {
currentReasoningPartIndex: -1,
toolCallIndices: new Map(),
};
- const { contentParts, toolCallIndices } = contentPartsState;
- const batcher = new FrameBatchedUpdater();
- let tokenUsageData: Record<string, unknown> | null = null;
- // Captured from ``data-turn-info`` at stream start; stamped
- // onto persisted messages so future edits can locate the
- // right LangGraph checkpoint.
- let streamedChatTurnId: string | null = null;
+ const { contentParts } = contentPartsState;
+ let regenerateAccepted = false;
+ let streamBatcher: FrameBatchedUpdater | null = null;
// Add placeholder messages to UI
// Always add back the user message (with new query for edit, or original content for reload)
@@ -1810,21 +1887,14 @@ export default function NewChatPage() {
createdAt: new Date(),
metadata: isEdit ? undefined : originalUserMessageMetadata,
};
- setMessages((prev) => [...prev, userMessage]);
-
- // Add placeholder assistant message
- setMessages((prev) => [
- ...prev,
- {
- id: assistantMsgId,
- role: "assistant",
- content: [{ type: "text", text: "" }],
- createdAt: new Date(),
- },
- ]);
-
+ const sourceMentionedDocs =
+ sourceUserMessageId && messageDocumentsMap[sourceUserMessageId]
+ ? messageDocumentsMap[sourceUserMessageId]
+ : [];
try {
- const selection = await getAgentFilesystemSelection(searchSpaceId);
+ const selection = await getAgentFilesystemSelection(searchSpaceId, {
+ localFilesystemEnabled,
+ });
const requestBody: Record<string, unknown> = {
search_space_id: searchSpaceId,
user_query: newUserQuery,
@@ -1832,6 +1902,18 @@ export default function NewChatPage() {
filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts,
+ // Full mention metadata for the regenerate-specific
+ // source list. Only meaningful for edit (the BE only
+ // re-persists a user row when ``user_query`` is set);
+ // reload reuses the original turn's mentioned_documents.
+ mentioned_documents:
+ sourceMentionedDocs.length > 0
+ ? sourceMentionedDocs.map((d) => ({
+ id: d.id,
+ title: d.title,
+ document_type: d.document_type,
+ }))
+ : undefined,
};
if (isEdit) {
requestBody.user_images = editExtras?.userImages ?? [];
@@ -1846,18 +1928,56 @@ export default function NewChatPage() {
requestBody.revert_actions = true;
}
}
- const response = await fetch(getRegenerateUrl(threadId), {
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- Authorization: `Bearer ${token}`,
- },
- body: JSON.stringify(requestBody),
- signal: controller.signal,
- });
+ const response = await fetchWithTurnCancellingRetry(() =>
+ fetch(getRegenerateUrl(threadId), {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ Authorization: `Bearer ${token}`,
+ },
+ body: JSON.stringify(requestBody),
+ signal: controller.signal,
+ })
+ );
if (!response.ok) {
- throw new Error(`Backend error: ${response.status}`);
+ throw await toHttpResponseError(response);
+ }
+ regenerateAccepted = true;
+
+ // Only switch the UI to the regenerated placeholder messages after the
+ // backend accepts the regenerate request. This avoids losing local
+ // messages when regenerate fails early (e.g. a 400).
+ //
+ // When an explicit ``editFromPosition.fromMessageId`` is passed, slice from
+ // that message forward so edit-from-arbitrary-position drops every downstream
+ // message; otherwise fall back to the legacy "drop the last 2" behaviour.
+ setMessages((prev) => {
+ let base = prev;
+ if (editFromPosition?.fromMessageId != null) {
+ const targetId = `msg-${editFromPosition.fromMessageId}`;
+ const sliceIndex = prev.findIndex((m) => m.id === targetId);
+ if (sliceIndex >= 0) {
+ base = prev.slice(0, sliceIndex);
+ }
+ } else if (prev.length >= 2) {
+ base = prev.slice(0, -2);
+ }
+ return [
+ ...base,
+ userMessage,
+ {
+ id: assistantMsgId,
+ role: "assistant",
+ content: [{ type: "text", text: "" }],
+ createdAt: new Date(),
+ },
+ ];
+ });
+ if (sourceMentionedDocs.length > 0) {
+ setMessageDocumentsMap((prev) => ({
+ ...prev,
+ [userMsgId]: sourceMentionedDocs,
+ }));
}
const flushMessages = () => {
@@ -1869,111 +1989,41 @@ export default function NewChatPage() {
)
);
};
- const scheduleFlush = () => batcher.schedule(flushMessages);
- const forceFlush = () => {
- scheduleFlush();
- batcher.flush();
- };
+ const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages);
+ streamBatcher = batcher;
- for await (const parsed of readSSEStream(response)) {
- switch (parsed.type) {
- case "text-delta":
- appendText(contentPartsState, parsed.delta);
- scheduleFlush();
- break;
-
- case "reasoning-delta":
- appendReasoning(contentPartsState, parsed.delta);
- scheduleFlush();
- break;
-
- case "reasoning-end":
- endReasoning(contentPartsState);
- scheduleFlush();
- break;
-
- case "start-step":
- addStepSeparator(contentPartsState);
- scheduleFlush();
- break;
-
- case "finish-step":
- break;
-
- case "tool-input-start":
- addToolCall(
- contentPartsState,
- toolsWithUI,
- parsed.toolCallId,
- parsed.toolName,
- {},
- false,
- parsed.langchainToolCallId
- );
- forceFlush();
- break;
-
- case "tool-input-delta":
- appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
- scheduleFlush();
- break;
-
- case "tool-input-available": {
- const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
- if (toolCallIndices.has(parsed.toolCallId)) {
- updateToolCall(contentPartsState, parsed.toolCallId, {
- args: parsed.input || {},
- argsText: finalArgsText,
- langchainToolCallId: parsed.langchainToolCallId,
- });
- } else {
- addToolCall(
- contentPartsState,
- toolsWithUI,
- parsed.toolCallId,
- parsed.toolName,
- parsed.input || {},
- false,
- parsed.langchainToolCallId
- );
- updateToolCall(contentPartsState, parsed.toolCallId, {
- argsText: finalArgsText,
- });
- }
- forceFlush();
- break;
- }
-
- case "tool-output-available":
- updateToolCall(contentPartsState, parsed.toolCallId, {
- result: parsed.output,
- langchainToolCallId: parsed.langchainToolCallId,
- });
- markInterruptsCompleted(contentParts);
- if (parsed.output?.status === "pending" && parsed.output?.podcast_id) {
- const idx = toolCallIndices.get(parsed.toolCallId);
- if (idx !== undefined) {
- const part = contentParts[idx];
- if (part?.type === "tool-call" && part.toolName === "generate_podcast") {
- setActivePodcastTaskId(String(parsed.output.podcast_id));
+ await consumeSseEvents(response, async (parsed) => {
+ if (
+ processSharedStreamEvent(parsed, {
+ contentPartsState,
+ toolsWithUI,
+ currentThinkingSteps,
+ scheduleFlush,
+ forceFlush,
+ onTokenUsage: (data) => {
+ tokenUsageStore.set(assistantMsgId, data);
+ },
+ onTurnStatus: (data) => {
+ if (data.status === "cancelling") {
+ recentCancelRequestedAtRef.current = Date.now();
+ }
+ },
+ onToolOutputAvailable: (event, sharedCtx) => {
+ if (event.output?.status === "pending" && event.output?.podcast_id) {
+ const idx = sharedCtx.toolCallIndices.get(event.toolCallId);
+ if (idx !== undefined) {
+ const part = sharedCtx.contentPartsState.contentParts[idx];
+ if (part?.type === "tool-call" && part.toolName === "generate_podcast") {
+ setActivePodcastTaskId(String(event.output.podcast_id));
+ }
}
}
- }
- forceFlush();
- break;
-
- case "data-thinking-step": {
- const stepData = parsed.data as ThinkingStepData;
- if (stepData?.id) {
- currentThinkingSteps.set(stepData.id, stepData);
- const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps);
- if (didUpdate) {
- scheduleFlush();
- }
- }
- break;
- }
-
+ },
+ })
+ ) {
+ return;
+ }
+ switch (parsed.type) {
case "data-action-log": {
if (threadId !== null) {
applyActionLogSse(queryClient, threadId, searchSpaceId, parsed.data);
@@ -1994,17 +2044,60 @@ export default function NewChatPage() {
}
case "data-turn-info": {
- streamedChatTurnId = parsed.data.chat_turn_id || null;
- if (streamedChatTurnId) {
+ const turnId = readStreamedChatTurnId(parsed.data);
+ if (turnId) {
setMessages((prev) =>
- prev.map((m) =>
- m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m
- )
+ applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
);
}
break;
}
+ case "data-user-message-id": {
+ // Same role as in ``onNew`` but the regenerate-specific
+ // mention metadata (``sourceMentionedDocs``) is the
+ // list to migrate onto the canonical id key.
+ const parsedMsg = readStreamedMessageId(parsed.data);
+ if (!parsedMsg) break;
+ const newUserMsgId = `msg-${parsedMsg.messageId}`;
+ const oldUserMsgId = userMsgId;
+ setMessages((prev) =>
+ prev.map((m) =>
+ m.id === oldUserMsgId
+ ? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, parsedMsg.turnId)
+ : m
+ )
+ );
+ if (sourceMentionedDocs.length > 0) {
+ setMessageDocumentsMap((prev) => {
+ if (!(oldUserMsgId in prev)) {
+ return { ...prev, [newUserMsgId]: sourceMentionedDocs };
+ }
+ const { [oldUserMsgId]: _removed, ...rest } = prev;
+ return { ...rest, [newUserMsgId]: sourceMentionedDocs };
+ });
+ }
+ userMsgId = newUserMsgId;
+ break;
+ }
+
+ case "data-assistant-message-id": {
+ const parsedMsg = readStreamedMessageId(parsed.data);
+ if (!parsedMsg) break;
+ const newAssistantMsgId = `msg-${parsedMsg.messageId}`;
+ const oldAssistantMsgId = assistantMsgId;
+ tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId);
+ setMessages((prev) =>
+ prev.map((m) =>
+ m.id === oldAssistantMsgId
+ ? mergeChatTurnIdIntoMessage({ ...m, id: newAssistantMsgId }, parsedMsg.turnId)
+ : m
+ )
+ );
+ assistantMsgId = newAssistantMsgId;
+ break;
+ }
+
case "data-revert-results": {
const summary = parsed.data;
// failureCount must include every "not undone" bucket
@@ -2040,95 +2133,48 @@ export default function NewChatPage() {
}
break;
}
-
- case "data-token-usage":
- tokenUsageData = parsed.data;
- tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
- break;
-
- case "error":
- throw new Error(parsed.errorText || "Server error");
}
- }
+ });
batcher.flush();
- // Persist messages after streaming completes
- const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
+ // Server-authoritative persistence: ``stream_new_chat``
+ // (regenerate flow) wrote the user row in
+ // ``persist_user_turn`` and finalises the assistant row
+ // in ``finalize_assistant_turn`` from a shielded
+ // ``finally`` block (covers both happy-path and
+ // abort-mid-stream). FE only needs to track the
+ // successful response here.
if (contentParts.length > 0) {
- try {
- // Persist user message (for both edit and reload modes, since backend deleted it)
- const userContentToPersist = isEdit
- ? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }])
- : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }];
-
- const savedUserMessage = await appendMessage(threadId, {
- role: "user",
- content: userContentToPersist,
- turn_id: streamedChatTurnId,
- });
-
- // Update user message ID to database ID
- const newUserMsgId = `msg-${savedUserMessage.id}`;
- setMessages((prev) =>
- prev.map((m) =>
- m.id === userMsgId
- ? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id)
- : m
- )
- );
-
- // Persist assistant message
- const savedMessage = await appendMessage(threadId, {
- role: "assistant",
- content: finalContent,
- token_usage: tokenUsageData ?? undefined,
- turn_id: streamedChatTurnId,
- });
-
- const newMsgId = `msg-${savedMessage.id}`;
- tokenUsageStore.rename(assistantMsgId, newMsgId);
- setMessages((prev) =>
- prev.map((m) =>
- m.id === assistantMsgId
- ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
- : m
- )
- );
-
- trackChatResponseReceived(searchSpaceId, threadId);
- } catch (err) {
- console.error("Failed to persist regenerated message:", err);
- }
+ trackChatResponseReceived(searchSpaceId, threadId);
}
} catch (error) {
- if (error instanceof Error && error.name === "AbortError") {
- return;
- }
- batcher.dispose();
- console.error("[NewChatPage] Regeneration error:", error);
- trackChatError(
- searchSpaceId,
+ streamBatcher?.dispose();
+ await handleStreamTerminalError({
+ error,
+ flow: "regenerate",
threadId,
- error instanceof Error ? error.message : "Unknown error"
- );
- toast.error("Failed to regenerate response. Please try again.");
- setMessages((prev) =>
- prev.map((m) =>
- m.id === assistantMsgId
- ? {
- ...m,
- content: [{ type: "text", text: "Sorry, there was an error. Please try again." }],
- }
- : m
- )
- );
+ assistantMsgId,
+ accepted: regenerateAccepted,
+ });
} finally {
setIsRunning(false);
abortControllerRef.current = null;
}
},
- [threadId, searchSpaceId, messages, disabledTools, tokenUsageStore, toolsWithUI]
+ [
+ threadId,
+ searchSpaceId,
+ messages,
+ disabledTools,
+ localFilesystemEnabled,
+ messageDocumentsMap,
+ setMessageDocumentsMap,
+ queryClient,
+ tokenUsageStore,
+ fetchWithTurnCancellingRetry,
+ handleStreamTerminalError,
+ ]
);
// Handle editing a message - truncates history and regenerates with new query.
@@ -2162,7 +2208,11 @@ export default function NewChatPage() {
if (fromMessageId == null) {
// No source id (or non-DB id) — fall back to today's
// last-2 behaviour. The user gets the legacy edit flow.
- await handleRegenerate(queryForApi, { userMessageContent, userImages });
+ await handleRegenerate(queryForApi, {
+ userMessageContent,
+ userImages,
+ sourceUserMessageId: sourceId,
+ });
return;
}
@@ -2211,7 +2261,7 @@ export default function NewChatPage() {
// Nothing to revert — submit silently.
await handleRegenerate(
queryForApi,
- { userMessageContent, userImages },
+ { userMessageContent, userImages, sourceUserMessageId: sourceId },
{ fromMessageId, revertActions: false }
);
return;
@@ -2246,6 +2296,7 @@ export default function NewChatPage() {
{
userMessageContent: pending.userMessageContent,
userImages: pending.userImages,
+ sourceUserMessageId: `msg-${pending.fromMessageId}`,
},
{
fromMessageId: pending.fromMessageId,
diff --git a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx
index 67d9edab0..85bc4aaa6 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx
@@ -1,11 +1,8 @@
"use client";
-import { useQueryClient } from "@tanstack/react-query";
import { CheckCircle2 } from "lucide-react";
import Link from "next/link";
import { useParams } from "next/navigation";
-import { useEffect } from "react";
-import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms";
import { Button } from "@/components/ui/button";
import {
Card,
@@ -18,14 +15,8 @@ import {
export default function PurchaseSuccessPage() {
const params = useParams();
- const queryClient = useQueryClient();
const searchSpaceId = String(params.search_space_id ?? "");
- useEffect(() => {
- void queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY });
- void queryClient.invalidateQueries({ queryKey: ["token-status"] });
- }, [queryClient]);
-
return (
- SurfSense needs two macOS permissions for Screenshot Assist and for desktop features that
- require focusing the app or the active application.
+ SurfSense needs two macOS permissions for Screenshot Assist and for desktop features
+ that require focusing the app or the active application.
)}
diff --git a/surfsense_web/components/editor/plate-editor.tsx b/surfsense_web/components/editor/plate-editor.tsx
index 7f12d3cae..c42cb991e 100644
--- a/surfsense_web/components/editor/plate-editor.tsx
+++ b/surfsense_web/components/editor/plate-editor.tsx
@@ -8,9 +8,11 @@ import { useEffect, useMemo, useRef } from "react";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
import { EditorSaveContext } from "@/components/editor/editor-save-context";
+import { CitationKit, injectCitationNodes } from "@/components/editor/plugins/citation-kit";
import { type EditorPreset, presetMap } from "@/components/editor/presets";
import { escapeMdxExpressions } from "@/components/editor/utils/escape-mdx";
import { Editor, EditorContainer } from "@/components/ui/editor";
+import { preprocessCitationMarkdown } from "@/lib/citations/citation-parser";
/** Live editor instance returned by `usePlateEditor`. */
export type PlateEditorInstance = ReturnType<typeof usePlateEditor>;
@@ -65,6 +67,14 @@ export interface PlateEditorProps {
* without modifying the core editor component.
*/
extraPlugins?: AnyPluginConfig[];
+ /**
+ * Render `[citation:N]` and `[citation:URL]` tokens in the deserialized
+ * markdown as interactive citation badges/popovers (mirrors chat). Only
+ * meant for read-only views — when true, `onMarkdownChange` is suppressed
+ * because the in-memory tree contains custom inline-void elements that
+ * have no markdown serialize rule.
+ */
+ enableCitations?: boolean;
}
function PlateEditorContent({
@@ -103,6 +113,7 @@ export function PlateEditor({
defaultEditing = false,
preset = "full",
extraPlugins = [],
+ enableCitations = false,
}: PlateEditorProps) {
const lastMarkdownRef = useRef(markdown);
const lastHtmlRef = useRef(html);
@@ -145,6 +156,8 @@ export function PlateEditor({
...(onSave ? [SaveShortcutPlugin] : []),
// Consumer-provided extra plugins
...extraPlugins,
+ // Citation void inline element (read-only document viewer).
+ ...(enableCitations ? CitationKit : []),
MarkdownPlugin.configure({
options: {
remarkPlugins: [remarkGfm, remarkMath, remarkMdx],
@@ -154,8 +167,18 @@ export function PlateEditor({
value: html
? (editor) => editor.api.html.deserialize({ element: html }) as Value
: markdown
- ? (editor) =>
- editor.getApi(MarkdownPlugin).markdown.deserialize(escapeMdxExpressions(markdown))
+ ? (editor) => {
+ if (!enableCitations) {
+ return editor
+ .getApi(MarkdownPlugin)
+ .markdown.deserialize(escapeMdxExpressions(markdown));
+ }
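+ // Citation pipeline: ``preprocessCitationMarkdown`` rewrites
+ // ``[citation:...]`` tokens into a form the markdown deserializer
+ // preserves (collecting URL targets in ``urlMap``), then
+ // ``injectCitationNodes`` swaps the surviving tokens for
+ // inline-void citation elements.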
+ const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown);
+ const value = editor
+ .getApi(MarkdownPlugin)
+ .markdown.deserialize(escapeMdxExpressions(rewritten));
+ return injectCitationNodes(value as Descendant[], urlMap) as Value;
+ }
: undefined,
});
@@ -174,13 +197,22 @@ export function PlateEditor({
useEffect(() => {
if (!html && markdown !== undefined && markdown !== lastMarkdownRef.current) {
lastMarkdownRef.current = markdown;
- const newValue = editor
- .getApi(MarkdownPlugin)
- .markdown.deserialize(escapeMdxExpressions(markdown));
+ let newValue: Descendant[];
+ if (enableCitations) {
+ const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown);
+ const deserialized = editor
+ .getApi(MarkdownPlugin)
+ .markdown.deserialize(escapeMdxExpressions(rewritten)) as Descendant[];
+ newValue = injectCitationNodes(deserialized, urlMap);
+ } else {
+ newValue = editor
+ .getApi(MarkdownPlugin)
+ .markdown.deserialize(escapeMdxExpressions(markdown)) as Descendant[];
+ }
editor.tf.reset();
- editor.tf.setValue(newValue);
+ editor.tf.setValue(newValue as Value);
}
- }, [html, markdown, editor]);
+ }, [html, markdown, editor, enableCitations]);
// When not forced read-only, the user can toggle between editing/viewing.
const canToggleMode = !readOnly && allowModeToggle;
@@ -205,6 +237,16 @@ export function PlateEditor({
// (initialized to true via usePlateEditor, toggled via ModeToolbarButton).
{...(readOnly ? { readOnly: true } : {})}
onChange={({ value }) => {
+ // View-only citation mode: skip serialization. The custom
+ // `citation` inline-void element has no markdown serialize
+ // rule, so emitting changes here would overwrite
+ // `lastMarkdownRef.current` (and downstream copy-to-clipboard
+ // state in EditorPanelContent) with a tree that loses every
+ // citation token. `enableCitations` is only ever set in
+ // read-only paths, so user input cannot reach this branch
+ // in practice — the guard exists for the initial Plate
+ // normalize emit.
+ if (enableCitations) return;
if (onHtmlChange && html) {
const serialized = slateToHtml(value as Descendant[]);
onHtmlChange(serialized);
diff --git a/surfsense_web/components/editor/plugins/citation-kit.tsx b/surfsense_web/components/editor/plugins/citation-kit.tsx
new file mode 100644
index 000000000..1908de209
--- /dev/null
+++ b/surfsense_web/components/editor/plugins/citation-kit.tsx
@@ -0,0 +1,218 @@
+"use client";
+
+import { type Descendant, KEYS } from "platejs";
+import { createPlatePlugin, type PlateElementProps } from "platejs/react";
+import type { FC } from "react";
+import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation";
+import {
+ CITATION_REGEX,
+ type CitationUrlMap,
+ parseTextWithCitations,
+} from "@/lib/citations/citation-parser";
+
+/**
+ * Plate inline-void node modeling a single `[citation:...]` reference.
+ *
+ * Modeled after the existing `MentionPlugin` pattern in
+ * `inline-mention-editor.tsx` — the only confirmed pattern in this repo
+ * for non-text inline UI. Inline-void elements satisfy Slate's invariant
+ * that the editor renders both atomic widgets and surrounding text
+ * cleanly without breaking selection / caret semantics.
+ */
+export type CitationElementNode = {
+ type: "citation";
+ kind: "chunk" | "doc" | "url";
+ chunkId?: number;
+ url?: string;
+ /** Original `[citation:...]` substring for traceability/debugging. */
+ rawText: string;
+ children: [{ text: "" }];
+};
+
+const CITATION_TYPE = "citation";
+
+const CitationElement: FC<PlateElementProps<CitationElementNode>> = ({
+  attributes,
+  children,
+  element,
+}) => {
+ const isUrl = element.kind === "url";
+  return (
+    <span {...attributes}>
+      <span contentEditable={false}>
+        {isUrl && element.url ? (
+          <UrlCitation url={element.url} />
+        ) : element.chunkId !== undefined ? (
+          <InlineCitation chunkId={element.chunkId} />
+        ) : null}
+      </span>
+      {children}
+    </span>
+  );
+};
+
+const CitationPlugin = createPlatePlugin({
+ key: CITATION_TYPE,
+ node: {
+ isElement: true,
+ isInline: true,
+ isVoid: true,
+ type: CITATION_TYPE,
+ component: CitationElement,
+ },
+});
+
+/** Plugin kit shape used elsewhere in the editor. */
+export const CitationKit = [CitationPlugin];
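+
+// Example node produced for a `[citation:3]` token (field values are
+// illustrative, not taken from a real document):
+//   { type: "citation", kind: "chunk", chunkId: 3,
+//     rawText: "[citation:3]", children: [{ text: "" }] }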
+
+// ---------------------------------------------------------------------------
+// Slate value transform — runs after MarkdownPlugin.deserialize
+// ---------------------------------------------------------------------------
+
+// Structural shapes used by the value transform. We cannot use Plate's
+// generic Element / Text type predicates directly because `Descendant` is a
+// constrained union and our predicates would over-narrow. Casting through
+// these row types keeps the walker readable without fighting the types.
+type SlateText = { text: string } & Record<string, unknown>;
+type SlateElement = { type?: string; children: Descendant[] } & Record<string, unknown>;
+
+function isText(node: Descendant): boolean {
+ return typeof (node as { text?: unknown }).text === "string";
+}
+
+function asText(node: Descendant): SlateText {
+ return node as unknown as SlateText;
+}
+
+function asElement(node: Descendant): SlateElement {
+ return node as unknown as SlateElement;
+}
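+
+// Example: isText({ text: "see [citation:3]" }) is true, so the walker can
+// hand that string to parseTextWithCitations; element nodes recurse through
+// `children` via asElement instead.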
+
+/**
+ * Element types whose subtrees we MUST NOT inject citation void elements
+ * into. Each rationale is documented in the citation plan:
+ * - `KEYS.codeBlock` / `code_line` — Plate's schema rejects inline elements
+ * inside code containers; the user expects literal text inside code.
+ * - `KEYS.link` — `
You've used all {limit.toLocaleString()} free tokens. Create a free account to
- get 3 million tokens and access to all models.
+ get $5 of premium credit and access to all models.
Create an account
{" "}
- for 5M free tokens.
+ for $5 of premium credit.