From efb5664afddee4279c3906113fffc8c3bf3a3f82 Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Thu, 30 Apr 2026 20:55:40 -0700
Subject: [PATCH 01/12] ci: add diagnostic notary-status workflow
Made-with: Cursor
---
.github/workflows/notary-status.yml | 60 +++++++++++++++++++++++++++++
1 file changed, 60 insertions(+)
create mode 100644 .github/workflows/notary-status.yml
diff --git a/.github/workflows/notary-status.yml b/.github/workflows/notary-status.yml
new file mode 100644
index 000000000..5c7c42038
--- /dev/null
+++ b/.github/workflows/notary-status.yml
@@ -0,0 +1,60 @@
+name: Notary status check
+
+# One-off diagnostic workflow. Queries Apple's notary service to see if your
+# submissions are queued, in progress, accepted, or rejected. Useful when a
+# notarization seems "hung" — most often the culprit is Apple's queue itself,
+# especially on a brand-new Apple Developer account.
+#
+# Run via: Actions tab -> "Notary status check" -> Run workflow.
+# Inputs are optional; if you provide a submission ID, it also fetches that
+# submission's full Apple log.
+#
+# Safe to delete after diagnosis.
+
+on:
+ workflow_dispatch:
+ inputs:
+ submission_id:
+ description: 'Optional: submission UUID to fetch full Apple log for'
+ required: false
+ default: ''
+
+jobs:
+ status:
+ runs-on: macos-latest
+ steps:
+ - name: List recent notarization submissions
+ env:
+ APPLE_ID: ${{ secrets.APPLE_ID }}
+ APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
+ APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
+ run: |
+ set -euo pipefail
+ echo "::group::Submission history (most recent first)"
+ xcrun notarytool history \
+ --apple-id "$APPLE_ID" \
+ --password "$APPLE_APP_SPECIFIC_PASSWORD" \
+ --team-id "$APPLE_TEAM_ID"
+ echo "::endgroup::"
+
+ - name: Inspect specific submission (if id provided)
+ if: ${{ inputs.submission_id != '' }}
+ env:
+ APPLE_ID: ${{ secrets.APPLE_ID }}
+ APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
+ APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
+ SUBMISSION_ID: ${{ inputs.submission_id }}
+ run: |
+ set -euo pipefail
+ echo "::group::Submission info"
+ xcrun notarytool info "$SUBMISSION_ID" \
+ --apple-id "$APPLE_ID" \
+ --password "$APPLE_APP_SPECIFIC_PASSWORD" \
+ --team-id "$APPLE_TEAM_ID"
+ echo "::endgroup::"
+ echo "::group::Apple's processing log for this submission"
+ xcrun notarytool log "$SUBMISSION_ID" \
+ --apple-id "$APPLE_ID" \
+ --password "$APPLE_APP_SPECIFIC_PASSWORD" \
+ --team-id "$APPLE_TEAM_ID" || true
+ echo "::endgroup::"
From ae9d36d77f26c4b74c65ed309fc0204dd4552b36 Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Sat, 2 May 2026 14:34:23 -0700
Subject: [PATCH 02/12] feat: unified credits and their cost calculations
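
The unit model in brief: premium credit is stored as integer micro-USD,
so 1_000_000 micros == $1.00 and the default
PREMIUM_CREDIT_MICROS_LIMIT of 5_000_000 is $5.00 of credit. A premium
turn that LiteLLM reports as costing $0.0135 debits roughly 13_500
micros (the USD cost times 1_000_000, rounding aside), and a one-unit
Stripe pack (STRIPE_CREDIT_MICROS_PER_UNIT=1000000) restores exactly
$1.00 of headroom. The legacy token env keys and columns map 1:1 to the
new unit because the prior Stripe price was $1 per 1M tokens.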
---
docker/.env.example | 28 +-
surfsense_backend/.env.example | 42 +-
.../140_premium_tokens_to_credit_micros.py | 291 ++++++++++
surfsense_backend/app/app.py | 2 +
surfsense_backend/app/celery_app.py | 2 +
surfsense_backend/app/config/__init__.py | 146 ++++-
.../app/config/global_llm_config.example.yaml | 29 +
surfsense_backend/app/db.py | 33 +-
.../app/etl_pipeline/etl_pipeline_service.py | 25 +-
.../app/routes/image_generation_routes.py | 154 +++++-
.../app/routes/new_chat_routes.py | 7 +-
.../app/routes/search_spaces_routes.py | 4 +
surfsense_backend/app/routes/stripe_routes.py | 57 +-
.../app/routes/vision_llm_routes.py | 7 +
.../app/schemas/image_generation.py | 18 +
surfsense_backend/app/schemas/new_chat.py | 1 +
surfsense_backend/app/schemas/stripe.py | 22 +-
surfsense_backend/app/schemas/vision_llm.py | 32 ++
.../app/services/billable_calls.py | 430 +++++++++++++++
.../app/services/llm_router_service.py | 58 +-
surfsense_backend/app/services/llm_service.py | 30 +-
.../openrouter_integration_service.py | 319 +++++++++++
.../app/services/pricing_registration.py | 274 ++++++++++
.../app/services/provider_api_base.py | 107 ++++
.../app/services/quota_checked_vision_llm.py | 105 ++++
.../app/services/token_quota_service.py | 125 ++++-
.../app/services/token_tracking_service.py | 239 +++++++-
.../app/services/vision_llm_router_service.py | 16 +-
.../app/tasks/celery_tasks/podcast_tasks.py | 67 ++-
.../celery_tasks/video_presentation_tasks.py | 68 ++-
.../app/tasks/chat/stream_new_chat.py | 112 ++--
.../tests/unit/routes/test_image_gen_quota.py | 138 +++++
.../services/test_agent_billing_resolver.py | 436 +++++++++++++++
.../tests/unit/services/test_billable_call.py | 432 +++++++++++++++
.../test_openrouter_integration_service.py | 156 ++++++
.../services/test_pricing_registration.py | 447 +++++++++++++++
.../services/test_quota_checked_vision_llm.py | 157 ++++++
.../services/test_token_quota_service_cost.py | 515 ++++++++++++++++++
.../tests/unit/tasks/test_podcast_billing.py | 325 +++++++++++
.../tasks/test_video_presentation_billing.py | 330 +++++++++++
surfsense_web/app/(home)/free/page.tsx | 4 +-
surfsense_web/app/(home)/pricing/page.tsx | 2 +-
.../[search_space_id]/buy-more/page.tsx | 2 +-
.../components/PurchaseHistoryContent.tsx | 27 +-
surfsense_web/atoms/user/user-query.atoms.ts | 6 +-
.../assistant-ui/assistant-message.tsx | 20 +
.../assistant-ui/token-usage-context.tsx | 21 +-
.../free-chat/quota-warning-banner.tsx | 4 +-
.../ui/sidebar/PremiumTokenUsageDisplay.tsx | 25 +-
.../components/pricing/pricing-section.tsx | 34 +-
.../settings/buy-tokens-content.tsx | 47 +-
.../settings/image-model-manager.tsx | 20 +-
.../components/settings/llm-role-manager.tsx | 27 +-
.../settings/vision-model-manager.tsx | 20 +-
surfsense_web/contexts/login-gate.tsx | 4 +-
.../contracts/types/new-llm-config.types.ts | 6 +
surfsense_web/contracts/types/stripe.types.ts | 15 +-
.../lib/chat/chat-error-classifier.ts | 2 +-
surfsense_web/lib/chat/streaming-state.ts | 9 +-
surfsense_web/lib/chat/thread-persistence.ts | 13 +-
surfsense_web/zero/schema/user.ts | 13 +-
61 files changed, 5835 insertions(+), 272 deletions(-)
create mode 100644 surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py
create mode 100644 surfsense_backend/app/services/billable_calls.py
create mode 100644 surfsense_backend/app/services/pricing_registration.py
create mode 100644 surfsense_backend/app/services/provider_api_base.py
create mode 100644 surfsense_backend/app/services/quota_checked_vision_llm.py
create mode 100644 surfsense_backend/tests/unit/routes/test_image_gen_quota.py
create mode 100644 surfsense_backend/tests/unit/services/test_agent_billing_resolver.py
create mode 100644 surfsense_backend/tests/unit/services/test_billable_call.py
create mode 100644 surfsense_backend/tests/unit/services/test_pricing_registration.py
create mode 100644 surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py
create mode 100644 surfsense_backend/tests/unit/services/test_token_quota_service_cost.py
create mode 100644 surfsense_backend/tests/unit/tasks/test_podcast_billing.py
create mode 100644 surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py
diff --git a/docker/.env.example b/docker/.env.example
index 95de0cf85..c2e87a619 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
# STRIPE_RECONCILIATION_BATCH_SIZE=100
-# Premium token purchases ($1 per 1M tokens for premium-tier models)
+# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
+# credit; premium turns debit the actual per-call provider cost
+# reported by LiteLLM, so cheap and expensive models bill proportionally)
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
-# STRIPE_TOKENS_PER_UNIT=1000000
+# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
+# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000
# ------------------------------------------------------------------------------
# TTS & STT (Text-to-Speech / Speech-to-Text)
@@ -315,9 +318,24 @@ STT_SERVICE=local/base
# Pages limit per user for ETL (default: unlimited)
# PAGES_LIMIT=500
-# Premium token quota per registered user (default: 5M)
-# Only applies to models with billing_tier=premium in global_llm_config.yaml
-# PREMIUM_TOKEN_LIMIT=5000000
+# Premium credit quota per registered user, in micro-USD (default: $5).
+# Premium turns are debited at the actual per-call provider cost reported
+# by LiteLLM. Only applies to models with billing_tier=premium.
+# PREMIUM_CREDIT_MICROS_LIMIT=5000000
+# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000
+
+# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
+# QUOTA_MAX_RESERVE_MICROS=1000000
+
+# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default).
+# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
+
+# Per-podcast reservation for the podcast Celery task ($0.20 default).
+# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
+
+# Per-video-presentation reservation for the video Celery task ($1.00 default).
+# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
+# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
# No-login (anonymous) mode — public users can chat without an account
# Set TRUE to enable /free pages and anonymous chat API
diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example
index a793f33d1..1b1478ae6 100644
--- a/surfsense_backend/.env.example
+++ b/surfsense_backend/.env.example
@@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000
# Set FALSE to disable new checkout session creation temporarily
STRIPE_PAGE_BUYING_ENABLED=TRUE
-# Premium token purchases via Stripe (for premium-tier model usage)
-# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens)
+# Premium credit purchases via Stripe (for premium-tier model usage).
+# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
+# (default 1_000_000 = $1.00). Premium turns are billed at the actual
+# per-call provider cost reported by LiteLLM.
STRIPE_TOKEN_BUYING_ENABLED=FALSE
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
-STRIPE_TOKENS_PER_UNIT=1000000
+STRIPE_CREDIT_MICROS_PER_UNIT=1000000
+# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
+# STRIPE_TOKENS_PER_UNIT=1000000
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
@@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
PAGES_LIMIT=500
-# Premium token quota per registered user (default: 3,000,000)
-# Applies only to models with billing_tier=premium in global_llm_config.yaml
-PREMIUM_TOKEN_LIMIT=3000000
+# Premium credit quota per registered user, in micro-USD
+# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
+# actual per-call provider cost reported by LiteLLM, so cheap and expensive
+# models bill proportionally. Applies only to models with
+# billing_tier=premium in global_llm_config.yaml.
+PREMIUM_CREDIT_MICROS_LIMIT=5000000
+# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
+# PREMIUM_TOKEN_LIMIT=5000000
+
+# Safety ceiling on per-call premium reservation, in micro-USD.
+# stream_new_chat estimates an upper-bound cost from the model's
+# litellm-published per-token rates × the config's quota_reserve_tokens
+# and clamps to this value so a misconfigured model can't lock the
+# user's whole balance on one call. Default $1.00.
+QUOTA_MAX_RESERVE_MICROS=1000000
+
+# Per-image reservation (in micro-USD) for the POST /image-generations
+# endpoint. Bypassed for free configs. Default $0.05.
+QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
+
+# Per-podcast reservation (in micro-USD) used by the podcast Celery task.
+# Single envelope covers one transcript-generation LLM call. Default $0.20.
+QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
+
+# Per-video-presentation reservation (in micro-USD) used by the video
+# presentation Celery task. Covers worst-case fan-out of N slide-scene
+# generations + refines. Default $1.00. NOTE: tasks using the override
+# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
+QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
# No-login (anonymous) mode — allows public users to chat without an account
# Set TRUE to enable /free pages and anonymous chat API
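
A minimal sketch of the reservation arithmetic the QUOTA_* variables above
control (illustrative helper and variable names, not code from this patch):

    QUOTA_MAX_RESERVE_MICROS = 1_000_000  # $1.00 ceiling

    def estimate_reserve_micros(input_cost_per_token: float,
                                output_cost_per_token: float,
                                quota_reserve_tokens: int) -> int:
        # Upper bound: price every reserved token at the costlier rate,
        # then clamp so one misconfigured model cannot lock the balance.
        worst_rate = max(input_cost_per_token, output_cost_per_token)
        estimate = int(quota_reserve_tokens * worst_rate * 1_000_000)
        return min(estimate, QUOTA_MAX_RESERVE_MICROS)

    # e.g. a $15/M-input, $75/M-output model with quota_reserve_tokens=4000
    # reserves 4000 * 0.000075 * 1_000_000 = 300_000 micros ($0.30).
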
diff --git a/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py
new file mode 100644
index 000000000..64aa699e8
--- /dev/null
+++ b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py
@@ -0,0 +1,291 @@
+"""rename premium token columns to credit micros and add cost_micros to token_usage
+
+Migrates the premium quota system from a flat token counter to a USD-cost
+based credit system, where 1 credit = 1 micro-USD ($0.000001).
+
+Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe
+price means every existing value is already correct in the new unit, no
+data transformation needed):
+
+ user.premium_tokens_limit -> premium_credit_micros_limit
+ user.premium_tokens_used -> premium_credit_micros_used
+ user.premium_tokens_reserved -> premium_credit_micros_reserved
+
+ premium_token_purchases.tokens_granted -> credit_micros_granted
+
+New column for cost auditing per turn:
+
+ token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0)
+
+The "user" table is in zero_publication's column list (added in 139), so
+this migration must drop and recreate the publication with the renamed
+column names, otherwise zero-cache will replicate stale column names and
+the FE Zero schema will fail to bind.
+
+IMPORTANT - apply this migration with the following steps, in order:
+ 1. Stop zero-cache (it holds replication locks that will deadlock DDL)
+ 2. Run: alembic upgrade head
+ 3. Delete / reset the zero-cache data volume
+ 4. Restart zero-cache (it will do a fresh initial sync)
+
+Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on
+"user". Skipping the data-volume reset will leave IndexedDB clients seeing
+column-not-found errors from a stale catalog snapshot.
+
+Revision ID: 140
+Revises: 139
+"""
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+
+from alembic import op
+
+revision: str = "140"
+down_revision: str | None = "139"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+PUBLICATION_NAME = "zero_publication"
+
+# Replicates 139's document column list verbatim — must stay in sync.
+DOCUMENT_COLS = [
+ "id",
+ "title",
+ "document_type",
+ "search_space_id",
+ "folder_id",
+ "created_by_id",
+ "status",
+ "created_at",
+ "updated_at",
+]
+
+# Same five live-meter fields as 139, with the renamed column names.
+USER_COLS = [
+ "id",
+ "pages_limit",
+ "pages_used",
+ "premium_credit_micros_limit",
+ "premium_credit_micros_used",
+]
+
+
+def _terminate_blocked_pids(conn, table: str) -> None:
+ """Kill backends whose locks on *table* would block our AccessExclusiveLock."""
+ conn.execute(
+ sa.text(
+ "SELECT pg_terminate_backend(l.pid) "
+ "FROM pg_locks l "
+ "JOIN pg_class c ON c.oid = l.relation "
+ "WHERE c.relname = :tbl "
+ " AND l.pid != pg_backend_pid()"
+ ),
+ {"tbl": table},
+ )
+
+
+def _has_zero_version(conn, table: str) -> bool:
+ return (
+ conn.execute(
+ sa.text(
+ "SELECT 1 FROM information_schema.columns "
+ "WHERE table_name = :tbl AND column_name = '_0_version'"
+ ),
+ {"tbl": table},
+ ).fetchone()
+ is not None
+ )
+
+
+def _column_exists(conn, table: str, column: str) -> bool:
+ return (
+ conn.execute(
+ sa.text(
+ "SELECT 1 FROM information_schema.columns "
+ "WHERE table_name = :tbl AND column_name = :col"
+ ),
+ {"tbl": table, "col": column},
+ ).fetchone()
+ is not None
+ )
+
+
+def _build_publication_ddl(
+ user_cols: list[str],
+ *,
+ documents_has_zero_ver: bool,
+ user_has_zero_ver: bool,
+) -> str:
+ doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
+ user_col_list_with_meta = user_cols + (
+ ['"_0_version"'] if user_has_zero_ver else []
+ )
+ doc_col_list = ", ".join(doc_cols)
+ user_col_list = ", ".join(user_col_list_with_meta)
+ return (
+ f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
+ f"notifications, "
+ f"documents ({doc_col_list}), "
+ f"folders, "
+ f"search_source_connectors, "
+ f"new_chat_messages, "
+ f"chat_comments, "
+ f"chat_session_state, "
+ f'"user" ({user_col_list})'
+ )
+
+
+def upgrade() -> None:
+ conn = op.get_bind()
+
+ # ------------------------------------------------------------------
+ # 1. Add cost_micros to token_usage. Idempotent guard so re-runs in
+ # dev environments are safe.
+ # ------------------------------------------------------------------
+ if not _column_exists(conn, "token_usage", "cost_micros"):
+ op.add_column(
+ "token_usage",
+ sa.Column(
+ "cost_micros",
+ sa.BigInteger(),
+ nullable=False,
+ server_default="0",
+ ),
+ )
+
+ # ------------------------------------------------------------------
+ # 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted.
+ # ------------------------------------------------------------------
+ if _column_exists(
+ conn, "premium_token_purchases", "tokens_granted"
+ ) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"):
+ op.alter_column(
+ "premium_token_purchases",
+ "tokens_granted",
+ new_column_name="credit_micros_granted",
+ )
+
+ # ------------------------------------------------------------------
+ # 3. Rename user.premium_tokens_* -> premium_credit_micros_*.
+ #
+ # We must drop the publication first (it references the old column
+ # names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE
+ # in a transaction block; alembic's outer transaction already holds
+ # one, but a SAVEPOINT keeps the LOCK + DDL atomic.
+ # ------------------------------------------------------------------
+ tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
+ with tx:
+ conn.execute(sa.text("SET lock_timeout = '10s'"))
+
+ _terminate_blocked_pids(conn, "user")
+ conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
+
+ # Re-assert REPLICA IDENTITY DEFAULT for safety; column-list
+ # publications require at least the PK to be in the column list,
+ # which is true for both the old and new shape.
+ conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
+
+ # Drop the publication BEFORE renaming columns, otherwise Postgres
+ # rejects the rename: "cannot drop column ... referenced by
+ # publication".
+ conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
+
+ for old, new in (
+ ("premium_tokens_limit", "premium_credit_micros_limit"),
+ ("premium_tokens_used", "premium_credit_micros_used"),
+ ("premium_tokens_reserved", "premium_credit_micros_reserved"),
+ ):
+ if _column_exists(conn, "user", old) and not _column_exists(
+ conn, "user", new
+ ):
+ op.alter_column("user", old, new_column_name=new)
+
+ # Update the server_default on the renamed limit column so newly
+ # inserted users get $5 of credit (== 5_000_000 micros) by
+ # default. Existing rows are unaffected.
+ op.alter_column(
+ "user",
+ "premium_credit_micros_limit",
+ server_default="5000000",
+ )
+
+ # Recreate the publication with the new column names.
+ documents_has_zero_ver = _has_zero_version(conn, "documents")
+ user_has_zero_ver = _has_zero_version(conn, "user")
+ conn.execute(
+ sa.text(
+ _build_publication_ddl(
+ USER_COLS,
+ documents_has_zero_ver=documents_has_zero_ver,
+ user_has_zero_ver=user_has_zero_ver,
+ )
+ )
+ )
+
+
+def downgrade() -> None:
+ """Revert the rename and drop ``cost_micros``.
+
+ Mirrors ``upgrade``: drop the publication, rename columns back, drop
+ the new column, recreate the publication with the old column list.
+ Same zero-cache stop/reset runbook applies in reverse.
+ """
+ conn = op.get_bind()
+
+ tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
+ with tx:
+ conn.execute(sa.text("SET lock_timeout = '10s'"))
+
+ _terminate_blocked_pids(conn, "user")
+ conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
+
+ conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
+
+ for new, old in (
+ ("premium_credit_micros_limit", "premium_tokens_limit"),
+ ("premium_credit_micros_used", "premium_tokens_used"),
+ ("premium_credit_micros_reserved", "premium_tokens_reserved"),
+ ):
+ if _column_exists(conn, "user", new) and not _column_exists(
+ conn, "user", old
+ ):
+ op.alter_column("user", new, new_column_name=old)
+
+ op.alter_column(
+ "user",
+ "premium_tokens_limit",
+ server_default="5000000",
+ )
+
+ legacy_user_cols = [
+ "id",
+ "pages_limit",
+ "pages_used",
+ "premium_tokens_limit",
+ "premium_tokens_used",
+ ]
+ documents_has_zero_ver = _has_zero_version(conn, "documents")
+ user_has_zero_ver = _has_zero_version(conn, "user")
+ conn.execute(
+ sa.text(
+ _build_publication_ddl(
+ legacy_user_cols,
+ documents_has_zero_ver=documents_has_zero_ver,
+ user_has_zero_ver=user_has_zero_ver,
+ )
+ )
+ )
+
+ if _column_exists(
+ conn, "premium_token_purchases", "credit_micros_granted"
+ ) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"):
+ op.alter_column(
+ "premium_token_purchases",
+ "credit_micros_granted",
+ new_column_name="tokens_granted",
+ )
+
+ if _column_exists(conn, "token_usage", "cost_micros"):
+ op.drop_column("token_usage", "cost_micros")
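
For reference, with neither "documents" nor "user" carrying a _0_version
column yet, the _build_publication_ddl(USER_COLS, ...) call in upgrade()
produces the statement below (the function returns it as a single line;
wrapped here for readability):

    CREATE PUBLICATION zero_publication FOR TABLE notifications,
      documents (id, title, document_type, search_space_id, folder_id,
                 created_by_id, status, created_at, updated_at),
      folders, search_source_connectors, new_chat_messages, chat_comments,
      chat_session_state,
      "user" (id, pages_limit, pages_used, premium_credit_micros_limit,
              premium_credit_micros_used)
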
diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py
index 016c2de42..14d7f4d23 100644
--- a/surfsense_backend/app/app.py
+++ b/surfsense_backend/app/app.py
@@ -31,6 +31,7 @@ from app.config import (
initialize_image_gen_router,
initialize_llm_router,
initialize_openrouter_integration,
+ initialize_pricing_registration,
initialize_vision_llm_router,
)
from app.db import User, create_db_and_tables, get_async_session
@@ -432,6 +433,7 @@ async def lifespan(app: FastAPI):
await setup_checkpointer_tables()
initialize_openrouter_integration()
_start_openrouter_background_refresh()
+ initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()
diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py
index 58a8b0f39..74710d5e1 100644
--- a/surfsense_backend/app/celery_app.py
+++ b/surfsense_backend/app/celery_app.py
@@ -22,10 +22,12 @@ def init_worker(**kwargs):
initialize_image_gen_router,
initialize_llm_router,
initialize_openrouter_integration,
+ initialize_pricing_registration,
initialize_vision_llm_router,
)
initialize_openrouter_integration()
+ initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()
diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py
index 675b05d2c..2aeeafb34 100644
--- a/surfsense_backend/app/config/__init__.py
+++ b/surfsense_backend/app/config/__init__.py
@@ -138,7 +138,11 @@ def load_global_image_gen_configs():
try:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
- return data.get("global_image_generation_configs", [])
+ configs = data.get("global_image_generation_configs", []) or []
+ for cfg in configs:
+ if isinstance(cfg, dict):
+ cfg.setdefault("billing_tier", "free")
+ return configs
except Exception as e:
print(f"Warning: Failed to load global image generation configs: {e}")
return []
@@ -153,7 +157,11 @@ def load_global_vision_llm_configs():
try:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
- return data.get("global_vision_llm_configs", [])
+ configs = data.get("global_vision_llm_configs", []) or []
+ for cfg in configs:
+ if isinstance(cfg, dict):
+ cfg.setdefault("billing_tier", "free")
+ return configs
except Exception as e:
print(f"Warning: Failed to load global vision LLM configs: {e}")
return []
@@ -254,6 +262,15 @@ def load_openrouter_integration_settings() -> dict | None:
"anonymous_enabled_free", settings["anonymous_enabled"]
)
+    # Image generation + vision LLM emissions are opt-in (issue L).
+ # OpenRouter's catalogue contains hundreds of image / vision
+ # capable models; auto-injecting all of them into every
+ # deployment would explode the model selector and surprise
+ # operators upgrading from prior versions. Default to False so
+ # admins must explicitly turn them on.
+ settings.setdefault("image_generation_enabled", False)
+ settings.setdefault("vision_enabled", False)
+
return settings
except Exception as e:
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
@@ -296,10 +313,60 @@ def initialize_openrouter_integration():
)
else:
print("Info: OpenRouter integration enabled but no models fetched")
+
+ # Image generation + vision LLM emissions are opt-in (issue L).
+ # Both reuse the catalogue already cached by ``service.initialize``
+ # so we don't make additional network calls here.
+ if settings.get("image_generation_enabled"):
+ try:
+ image_configs = service.get_image_generation_configs()
+ if image_configs:
+ config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
+ print(
+ f"Info: OpenRouter integration added {len(image_configs)} "
+ f"image-generation models"
+ )
+ except Exception as e:
+ print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
+
+ if settings.get("vision_enabled"):
+ try:
+ vision_configs = service.get_vision_llm_configs()
+ if vision_configs:
+ config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
+ print(
+ f"Info: OpenRouter integration added {len(vision_configs)} "
+ f"vision LLM models"
+ )
+ except Exception as e:
+ print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
except Exception as e:
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
+def initialize_pricing_registration():
+ """
+ Teach LiteLLM the per-token cost of every deployment in
+ ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled
+ from the OpenRouter catalogue + any operator-declared YAML pricing).
+
+ Must run AFTER ``initialize_openrouter_integration()`` so the
+ OpenRouter catalogue is populated and BEFORE the first LLM call so
+ ``response_cost`` is available in ``TokenTrackingCallback``.
+
+ Failures are logged but never raised — startup must not be blocked
+ by a missing pricing entry; the worst-case is the model debits 0.
+ """
+ try:
+ from app.services.pricing_registration import (
+ register_pricing_from_global_configs,
+ )
+
+ register_pricing_from_global_configs()
+ except Exception as e:
+ print(f"Warning: Failed to register LiteLLM pricing: {e}")
+
+
def initialize_llm_router():
"""
Initialize the LLM Router service for Auto mode.
@@ -444,14 +511,54 @@ class Config:
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
)
- # Premium token quota settings
- PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
+ # Premium credit (micro-USD) quota settings.
+ #
+ # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
+ # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
+ # still honoured for one release as fall-back values — the prior
+ # $1-per-1M-tokens Stripe price means every existing value maps 1:1
+ # to micros, so operators upgrading without changing their .env still
+ # get correct behaviour. A startup deprecation warning fires below if
+ # they're set.
+ PREMIUM_CREDIT_MICROS_LIMIT = int(
+ os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
+ or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
+ )
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
- STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
+ STRIPE_CREDIT_MICROS_PER_UNIT = int(
+ os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
+ or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
+ )
STRIPE_TOKEN_BUYING_ENABLED = (
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
)
+ # Safety ceiling on the per-call premium reservation. ``stream_new_chat``
+    # estimates an upper-bound cost from ``litellm.get_model_info`` × the
+ # config's ``quota_reserve_tokens`` and clamps the result to this value
+ # so a misconfigured "$1000/M" model can't lock the user's whole balance
+ # on one call. Default $1.00 covers realistic worst-cases (Opus + 4K
+ # reserve_tokens ≈ $0.36) with headroom.
+ QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
+
+ if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
+ "PREMIUM_CREDIT_MICROS_LIMIT"
+ ):
+ print(
+ "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
+ "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
+ "current Stripe price). The old key will be removed in a "
+ "future release."
+ )
+ if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
+ "STRIPE_CREDIT_MICROS_PER_UNIT"
+ ):
+ print(
+ "Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
+ "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
+ "The old key will be removed in a future release."
+ )
+
# Anonymous / no-login mode settings
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000"))
@@ -464,6 +571,35 @@ class Config:
# Default quota reserve tokens when not specified per-model
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
+ # Per-image reservation (in micro-USD) used by ``billable_call`` for the
+ # ``POST /image-generations`` endpoint when the global config does not
+ # override it. $0.05 covers realistic worst-cases for current OpenAI /
+ # OpenRouter image-gen pricing. Bypassed entirely for free configs.
+ QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
+ os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
+ )
+
+ # Per-podcast reservation (in micro-USD). One agent LLM call generating
+ # a transcript, typically 5k-20k completion tokens. $0.20 covers a long
+ # premium-model run. Tune via env.
+ QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
+ os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
+ )
+
+ # Per-video-presentation reservation (in micro-USD). Fan-out of N
+ # slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
+ # plus refine retries; can produce many premium completions. $1.00
+ # covers worst-case. Tune via env.
+ #
+ # NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
+ # 1_000_000. The override path in ``billable_call`` bypasses the
+ # per-call clamp in ``estimate_call_reserve_micros``, so this is the
+ # *actual* hold — raising it via env is fine but means a single video
+ # task can lock $1+ of credit.
+ QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
+ os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
+ )
+
# Abuse prevention: concurrent stream cap and CAPTCHA
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
ANON_CAPTCHA_REQUEST_THRESHOLD = int(
diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml
index 79cbe1e51..d92640c8d 100644
--- a/surfsense_backend/app/config/global_llm_config.example.yaml
+++ b/surfsense_backend/app/config/global_llm_config.example.yaml
@@ -19,6 +19,24 @@
# Structure matches NewLLMConfig:
# - Model configuration (provider, model_name, api_key, etc.)
# - Prompt configuration (system_instructions, citations_enabled)
+#
+# COST-BASED PREMIUM CREDITS:
+# Each premium config bills the user's USD-credit balance based on the
+# actual provider cost reported by LiteLLM. For models LiteLLM already
+# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
+# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
+# or any model LiteLLM doesn't have in its built-in pricing table, declare
+# per-token costs inline so they bill correctly:
+#
+# litellm_params:
+# base_model: "my-custom-azure-deploy"
+# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
+# input_cost_per_token: 0.000003
+# output_cost_per_token: 0.000015
+#
+# OpenRouter dynamic models pull pricing automatically from OpenRouter's
+# API — no inline declaration needed. Models without resolvable pricing
+# debit $0 from the user's balance and log a WARNING.
# Router Settings for Auto Mode
# These settings control how the LiteLLM Router distributes requests across models
@@ -292,6 +310,17 @@ openrouter_integration:
free_rpm: 20
free_tpm: 100000
+  # Image generation + vision LLM emissions are OPT-IN. OpenRouter's catalogue
+ # contains hundreds of image- and vision-capable models; turning these on
+ # injects them into the global Image-Generation / Vision-LLM model
+ # selectors alongside any static configs. Tier (free/premium) is derived
+ # per model the same way it is for chat (`:free` suffix or zero pricing).
+ # When a user picks a premium image/vision model the call debits the
+ # shared $5 USD-cost-based premium credit pool — so leaving these off
+ # avoids surprise quota burn on existing deployments. Default: false.
+ image_generation_enabled: false
+ vision_enabled: false
+
litellm_params:
max_tokens: 16384
system_instructions: ""
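
As a worked example of the inline pricing block above (variable names are
illustrative, not identifiers from this patch): with the example rates of
$3.00/M input and $15.00/M output, a call using 2,000 prompt tokens and
500 completion tokens would be billed

    cost_usd = 2_000 * 0.000003 + 500 * 0.000015  # 0.006 + 0.0075 = 0.0135
    cost_micros = round(cost_usd * 1_000_000)     # 13_500 micros debited

i.e. about 1.35 cents against the user's premium credit balance.
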
diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py
index 2fe478d9b..aef959ec9 100644
--- a/surfsense_backend/app/db.py
+++ b/surfsense_backend/app/db.py
@@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin):
prompt_tokens = Column(Integer, nullable=False, default=0)
completion_tokens = Column(Integer, nullable=False, default=0)
total_tokens = Column(Integer, nullable=False, default=0)
+ cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
model_breakdown = Column(JSONB, nullable=True)
call_details = Column(JSONB, nullable=True)
@@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin):
class PremiumTokenPurchase(Base, TimestampMixin):
- """Tracks Stripe checkout sessions used to grant additional premium token credits."""
+ """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
+
+ Note: the table name is preserved (``premium_token_purchases``) for
+ operational continuity even though the unit is now USD micro-credits
+ instead of raw tokens. The ``credit_micros_granted`` column replaced
+ the legacy ``tokens_granted`` in migration 140; the stored values
+ were not transformed because the prior $1 = 1M tokens Stripe price
+ makes the unit conversion 1:1 numerically.
+ """
__tablename__ = "premium_token_purchases"
__allow_unmapped__ = True
@@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
)
stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
quantity = Column(Integer, nullable=False)
- tokens_granted = Column(BigInteger, nullable=False)
+ credit_micros_granted = Column(BigInteger, nullable=False)
amount_total = Column(Integer, nullable=True)
currency = Column(String(10), nullable=True)
status = Column(
@@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE":
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
- premium_tokens_limit = Column(
+ premium_credit_micros_limit = Column(
BigInteger,
nullable=False,
- default=config.PREMIUM_TOKEN_LIMIT,
- server_default=str(config.PREMIUM_TOKEN_LIMIT),
+ default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+ server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
)
- premium_tokens_used = Column(
+ premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
- premium_tokens_reserved = Column(
+ premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
@@ -2241,16 +2250,16 @@ else:
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
- premium_tokens_limit = Column(
+ premium_credit_micros_limit = Column(
BigInteger,
nullable=False,
- default=config.PREMIUM_TOKEN_LIMIT,
- server_default=str(config.PREMIUM_TOKEN_LIMIT),
+ default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+ server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
)
- premium_tokens_used = Column(
+ premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
- premium_tokens_reserved = Column(
+ premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
diff --git a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py
index 4bb38b7b0..d45bd780c 100644
--- a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py
+++ b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py
@@ -68,12 +68,25 @@ class EtlPipelineService:
etl_service="VISION_LLM",
content_type="image",
)
- except Exception:
- logging.warning(
- "Vision LLM failed for %s, falling back to document parser",
- request.filename,
- exc_info=True,
- )
+ except Exception as exc:
+ # Special-case quota exhaustion so we log a clearer message
+ # — the vision LLM didn't "fail", the user just ran out of
+ # premium credit. Falling through to the document parser
+ # is a graceful degradation: OCR/Unstructured still
+ # extracts text from the image without burning credit.
+ from app.services.billable_calls import QuotaInsufficientError
+
+ if isinstance(exc, QuotaInsufficientError):
+ logging.info(
+ "Vision LLM quota exhausted for %s; falling back to document parser",
+ request.filename,
+ )
+ else:
+ logging.warning(
+ "Vision LLM failed for %s, falling back to document parser",
+ request.filename,
+ exc_info=True,
+ )
else:
logging.info(
"No vision LLM provided, falling back to document parser for %s",
diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py
index 97a3559b9..34ed80207 100644
--- a/surfsense_backend/app/routes/image_generation_routes.py
+++ b/surfsense_backend/app/routes/image_generation_routes.py
@@ -36,6 +36,11 @@ from app.schemas import (
ImageGenerationListRead,
ImageGenerationRead,
)
+from app.services.billable_calls import (
+ DEFAULT_IMAGE_RESERVE_MICROS,
+ QuotaInsufficientError,
+ billable_call,
+)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
@@ -92,6 +97,50 @@ def _build_model_string(
return f"{prefix}/{model_name}"
+async def _resolve_billing_for_image_gen(
+ session: AsyncSession,
+ config_id: int | None,
+ search_space: SearchSpace,
+) -> tuple[str, str, int]:
+ """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
+
+ The resolution mirrors ``_execute_image_generation``'s lookup tree but
+ only extracts the fields needed for billing — we do this *before*
+ ``billable_call`` so the reservation is correctly sized for the
+ config that will actually run, and so we don't open an
+ ``ImageGeneration`` row for a request that's about to 402.
+
+ User-owned (positive ID) BYOK configs are always free — they cost
+    the user nothing on our side. Auto mode is currently treated as free
+ because the underlying router can dispatch to either premium or
+ free YAML configs and we don't surface the resolved deployment up
+ here yet. Bringing Auto under premium billing would require
+ threading the chosen deployment back from ``ImageGenRouterService``.
+ """
+ resolved_id = config_id
+ if resolved_id is None:
+ resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
+
+ if is_image_gen_auto_mode(resolved_id):
+ return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
+
+ if resolved_id < 0:
+ cfg = _get_global_image_gen_config(resolved_id) or {}
+ billing_tier = str(cfg.get("billing_tier", "free")).lower()
+ base_model = _build_model_string(
+ cfg.get("provider", ""),
+ cfg.get("model_name", ""),
+ cfg.get("custom_provider"),
+ )
+ reserve_micros = int(
+ cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
+ )
+ return (billing_tier, base_model, reserve_micros)
+
+ # Positive ID = user-owned BYOK image-gen config — always free.
+ return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
+
+
async def _execute_image_generation(
session: AsyncSession,
image_gen: ImageGeneration,
@@ -225,6 +274,9 @@ async def get_global_image_gen_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
+ # Auto mode currently treated as free until per-deployment
+ # billing-tier surfacing lands (see _resolve_billing_for_image_gen).
+ "billing_tier": "free",
}
)
@@ -241,6 +293,8 @@ async def get_global_image_gen_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": cfg.get("billing_tier", "free"),
+ "quota_reserve_micros": cfg.get("quota_reserve_micros"),
}
)
@@ -454,7 +508,26 @@ async def create_image_generation(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
- """Create and execute an image generation request."""
+ """Create and execute an image generation request.
+
+ Premium configs are gated by the user's shared premium credit pool.
+ The flow is:
+
+ 1. Permission check + load the search space (cheap, no provider call).
+ 2. Resolve which config will run so we know its billing tier and the
+ worst-case reservation size *before* opening any DB rows.
+ 3. Wrap the entire ImageGeneration row insert + provider call in
+ ``billable_call``. If quota is denied, ``billable_call`` raises
+ ``QuotaInsufficientError`` *before* we flush a row, which we
+ translate to HTTP 402 (no orphaned rows on the user's account,
+ no inserted error rows for "you ran out of credit").
+ 4. On success, the actual ``response_cost`` flows through the
+ LiteLLM callback into the accumulator, and ``billable_call``
+ finalizes the debit at exit. Inner ``try/except`` still catches
+ provider errors and stores them on ``error_message`` (HTTP 200
+ with ``error_message`` set is preserved for failed-but-not-quota
+ scenarios — clients already know how to surface those).
+ """
try:
await check_permission(
session,
@@ -471,33 +544,70 @@ async def create_image_generation(
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
- db_image_gen = ImageGeneration(
- prompt=data.prompt,
- model=data.model,
- n=data.n,
- quality=data.quality,
- size=data.size,
- style=data.style,
- response_format=data.response_format,
- image_generation_config_id=data.image_generation_config_id,
- search_space_id=data.search_space_id,
- created_by_id=user.id,
+ billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
+ session, data.image_generation_config_id, search_space
)
- session.add(db_image_gen)
- await session.flush()
- try:
- await _execute_image_generation(session, db_image_gen, search_space)
- except Exception as e:
- logger.exception("Image generation call failed")
- db_image_gen.error_message = str(e)
+ # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
+ # propagates to the outer ``except QuotaInsufficientError`` handler
+ # below as HTTP 402 — it is intentionally NOT swallowed into
+ # ``error_message`` because that would (1) imply a successful row
+ # exists when none does, and (2) return HTTP 200 to a client
+ # whose request was actively *denied* (issue K).
+ async with billable_call(
+ user_id=search_space.user_id,
+ search_space_id=data.search_space_id,
+ billing_tier=billing_tier,
+ base_model=base_model,
+ quota_reserve_micros_override=reserve_micros,
+ usage_type="image_generation",
+ call_details={"model": base_model, "prompt": data.prompt[:100]},
+ ):
+ db_image_gen = ImageGeneration(
+ prompt=data.prompt,
+ model=data.model,
+ n=data.n,
+ quality=data.quality,
+ size=data.size,
+ style=data.style,
+ response_format=data.response_format,
+ image_generation_config_id=data.image_generation_config_id,
+ search_space_id=data.search_space_id,
+ created_by_id=user.id,
+ )
+ session.add(db_image_gen)
+ await session.flush()
- await session.commit()
- await session.refresh(db_image_gen)
- return db_image_gen
+ try:
+ await _execute_image_generation(session, db_image_gen, search_space)
+ except Exception as e:
+ logger.exception("Image generation call failed")
+ db_image_gen.error_message = str(e)
+
+ await session.commit()
+ await session.refresh(db_image_gen)
+ return db_image_gen
except HTTPException:
raise
+ except QuotaInsufficientError as exc:
+ # The user's premium credit pool is empty. No DB row is created
+ # because ``billable_call`` denies before yielding (issue K).
+ await session.rollback()
+ raise HTTPException(
+ status_code=402,
+ detail={
+ "error_code": "premium_quota_exhausted",
+ "usage_type": exc.usage_type,
+ "used_micros": exc.used_micros,
+ "limit_micros": exc.limit_micros,
+ "remaining_micros": exc.remaining_micros,
+ "message": (
+ "Out of premium credits for image generation. "
+ "Purchase additional credits or switch to a free model."
+ ),
+ },
+ ) from exc
except SQLAlchemyError:
await session.rollback()
raise HTTPException(
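
For client authors, a denied premium image generation surfaces as HTTP 402
with the detail shape raised above (keys and message text mirror the
handler; the values shown are illustrative):

    detail = {
        "error_code": "premium_quota_exhausted",
        "usage_type": "image_generation",
        "used_micros": 5_000_000,    # $5.00 already spent
        "limit_micros": 5_000_000,   # $5.00 total credit
        "remaining_micros": 0,
        "message": (
            "Out of premium credits for image generation. "
            "Purchase additional credits or switch to a free model."
        ),
    }
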
diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py
index 28b197ca2..d3bd51129 100644
--- a/surfsense_backend/app/routes/new_chat_routes.py
+++ b/surfsense_backend/app/routes/new_chat_routes.py
@@ -1366,7 +1366,11 @@ async def append_message(
# flush assigns the PK/defaults without a round-trip SELECT
await session.flush()
- # Persist token usage if provided (for assistant messages)
+ # Persist token usage if provided (for assistant messages).
+ # ``cost_micros`` is the provider USD cost reported by LiteLLM,
+ # forwarded by the FE through the appendMessage round-trip so
+ # the historical TokenUsage row matches the credit debit applied
+ # at finalize time.
token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
await record_token_usage(
@@ -1377,6 +1381,7 @@ async def append_message(
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
+ cost_micros=token_usage_data.get("cost_micros", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
thread_id=thread_id,
diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py
index 72715ea5b..5ecfb1814 100644
--- a/surfsense_backend/app/routes/search_spaces_routes.py
+++ b/surfsense_backend/app/routes/search_spaces_routes.py
@@ -594,6 +594,7 @@ async def _get_image_gen_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
+ "billing_tier": "free",
}
if config_id < 0:
@@ -610,6 +611,7 @@ async def _get_image_gen_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": cfg.get("billing_tier", "free"),
}
return None
@@ -652,6 +654,7 @@ async def _get_vision_llm_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
+ "billing_tier": "free",
}
if config_id < 0:
@@ -668,6 +671,7 @@ async def _get_vision_llm_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": cfg.get("billing_tier", "free"),
}
return None
diff --git a/surfsense_backend/app/routes/stripe_routes.py b/surfsense_backend/app/routes/stripe_routes.py
index cfdd4b52a..aed74ec8d 100644
--- a/surfsense_backend/app/routes/stripe_routes.py
+++ b/surfsense_backend/app/routes/stripe_routes.py
@@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
metadata = _get_metadata(checkout_session)
user_id = metadata.get("user_id")
quantity = int(metadata.get("quantity", "0"))
- tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
+ # Read the new metadata key first, fall back to the legacy one so
+ # in-flight checkout sessions created before the cost-credits
+ # release still fulfil correctly (the unit is numerically the
+ # same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
+ credit_micros_per_unit = int(
+ metadata.get("credit_micros_per_unit")
+ or metadata.get("tokens_per_unit", "0")
+ )
- if not user_id or quantity <= 0 or tokens_per_unit <= 0:
+ if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
logger.error(
"Skipping token fulfillment for session %s: incomplete metadata %s",
checkout_session_id,
@@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
getattr(checkout_session, "payment_intent", None)
),
quantity=quantity,
- tokens_granted=quantity * tokens_per_unit,
+ credit_micros_granted=quantity * credit_micros_per_unit,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
purchase.stripe_payment_intent_id = _normalize_optional_string(
getattr(checkout_session, "payment_intent", None)
)
- user.premium_tokens_limit = (
- max(user.premium_tokens_used, user.premium_tokens_limit)
- + purchase.tokens_granted
+ # Top up the user's credit balance by the granted micro-USD amount.
+    # ``max(used, limit)`` handles the case where the legacy code left the
+    # used value above the limit (e.g. a rounding overshoot), so adding
+ # ``credit_micros_granted`` always lifts the limit by the full pack
+ # size rather than disappearing into past overuse.
+ user.premium_credit_micros_limit = (
+ max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
+ + purchase.credit_micros_granted
)
await db_session.commit()
@@ -532,12 +544,18 @@ async def create_token_checkout_session(
user: User = Depends(current_active_user),
db_session: AsyncSession = Depends(get_async_session),
):
- """Create a Stripe Checkout Session for buying premium token packs."""
+ """Create a Stripe Checkout Session for buying premium credit packs.
+
+ Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
+ credit (default 1_000_000 = $1.00). The user's balance is debited
+ at the actual provider cost reported by LiteLLM at finalize time,
+ so $1 of credit always buys $1 worth of provider usage at cost.
+ """
_ensure_token_buying_enabled()
stripe_client = get_stripe_client()
price_id = _get_required_token_price_id()
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
- tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
+ credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
try:
checkout_session = stripe_client.v1.checkout.sessions.create(
@@ -556,8 +574,8 @@ async def create_token_checkout_session(
"metadata": {
"user_id": str(user.id),
"quantity": str(body.quantity),
- "tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
- "purchase_type": "premium_tokens",
+ "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
+ "purchase_type": "premium_credit",
},
}
)
@@ -583,7 +601,7 @@ async def create_token_checkout_session(
getattr(checkout_session, "payment_intent", None)
),
quantity=body.quantity,
- tokens_granted=tokens_granted,
+ credit_micros_granted=credit_micros_granted,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@@ -598,14 +616,19 @@ async def create_token_checkout_session(
async def get_token_status(
user: User = Depends(current_active_user),
):
- """Return token-buying availability and current premium quota for frontend."""
- used = user.premium_tokens_used
- limit = user.premium_tokens_limit
+ """Return token-buying availability and current premium credit quota for frontend.
+
+ Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
+ when displaying. The route name is preserved for back-compat with
+ pinned client deployments.
+ """
+ used = user.premium_credit_micros_used
+ limit = user.premium_credit_micros_limit
return TokenStripeStatusResponse(
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
- premium_tokens_used=used,
- premium_tokens_limit=limit,
- premium_tokens_remaining=max(0, limit - used),
+ premium_credit_micros_used=used,
+ premium_credit_micros_limit=limit,
+ premium_credit_micros_remaining=max(0, limit - used),
)
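
A worked example of the max()-based top-up in _fulfill_completed_token_purchase
above, with illustrative numbers showing why the clamp matters:

    used, limit, granted = 5_100_000, 5_000_000, 1_000_000  # legacy overshoot
    new_limit = max(used, limit) + granted                  # 6_100_000
    remaining = new_limit - used                            # 1_000_000, the full $1.00 pack
    # Without max(): 5_000_000 + 1_000_000 - 5_100_000 = 900_000 ($0.90).
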
diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py
index 315c7c9fe..4f7e9f725 100644
--- a/surfsense_backend/app/routes/vision_llm_routes.py
+++ b/surfsense_backend/app/routes/vision_llm_routes.py
@@ -82,6 +82,9 @@ async def get_global_vision_llm_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
+ # Auto mode treated as free until per-deployment billing-tier
+ # surfacing lands; see ``get_vision_llm`` for parity.
+ "billing_tier": "free",
}
)
@@ -98,6 +101,10 @@ async def get_global_vision_llm_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": cfg.get("billing_tier", "free"),
+ "quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
+ "input_cost_per_token": cfg.get("input_cost_per_token"),
+ "output_cost_per_token": cfg.get("output_cost_per_token"),
}
)
diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py
index 69f534e20..facca7b86 100644
--- a/surfsense_backend/app/schemas/image_generation.py
+++ b/surfsense_backend/app/schemas/image_generation.py
@@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel):
Schema for reading global image generation configs from YAML.
Global configs have negative IDs. API key is hidden.
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
+
+ The ``billing_tier`` field allows the frontend to show a Premium/Free
+ badge and (more importantly) tells the backend whether to debit the
+ user's premium credit pool when this config is used. ``"free"`` is
+ the default for backward compatibility — admins must explicitly opt
+ a global config into ``"premium"``.
"""
id: int = Field(
@@ -231,3 +237,15 @@ class GlobalImageGenConfigRead(BaseModel):
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
+ billing_tier: str = Field(
+ default="free",
+ description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
+ )
+ quota_reserve_micros: int | None = Field(
+ default=None,
+ description=(
+ "Optional override for the reservation amount (in micro-USD) used when "
+ "this image generation is premium. Falls back to "
+ "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
+ ),
+ )
diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py
index ec5eefc07..892ff9693 100644
--- a/surfsense_backend/app/schemas/new_chat.py
+++ b/surfsense_backend/app/schemas/new_chat.py
@@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
+ cost_micros: int = 0
model_breakdown: dict | None = None
model_config = ConfigDict(from_attributes=True)
diff --git a/surfsense_backend/app/schemas/stripe.py b/surfsense_backend/app/schemas/stripe.py
index 3edd3e9e4..57265ec8e 100644
--- a/surfsense_backend/app/schemas/stripe.py
+++ b/surfsense_backend/app/schemas/stripe.py
@@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel):
class TokenPurchaseRead(BaseModel):
- """Serialized premium token purchase record."""
+ """Serialized premium credit purchase record.
+
+ ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The
+ schema name kept ``Token`` for API back-compat with pinned clients.
+ """
id: uuid.UUID
stripe_checkout_session_id: str
stripe_payment_intent_id: str | None = None
quantity: int
- tokens_granted: int
+ credit_micros_granted: int
amount_total: int | None = None
currency: str | None = None
status: str
@@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel):
class TokenPurchaseHistoryResponse(BaseModel):
- """Response containing the user's premium token purchases."""
+ """Response containing the user's premium credit purchases."""
purchases: list[TokenPurchaseRead]
class TokenStripeStatusResponse(BaseModel):
- """Response describing token-buying availability and current quota."""
+ """Response describing premium-credit-buying availability and balance.
+
+ All ``premium_credit_micros_*`` fields are in micro-USD; the FE
+ divides by 1_000_000 to display USD.
+ """
token_buying_enabled: bool
- premium_tokens_used: int = 0
- premium_tokens_limit: int = 0
- premium_tokens_remaining: int = 0
+ premium_credit_micros_used: int = 0
+ premium_credit_micros_limit: int = 0
+ premium_credit_micros_remaining: int = 0
diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py
index ab2e609dc..e55333a9d 100644
--- a/surfsense_backend/app/schemas/vision_llm.py
+++ b/surfsense_backend/app/schemas/vision_llm.py
@@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel):
class GlobalVisionLLMConfigRead(BaseModel):
+ """Schema for reading global vision LLM configs from YAML.
+
+ The ``billing_tier`` field allows the frontend to show a Premium/Free
+ badge and (more importantly) tells the backend whether to debit the
+ user's premium credit pool when this config is used. ``"free"`` is
+ the default for backward compatibility — admins must explicitly opt
+ a global config into ``"premium"``.
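+
+    Illustrative YAML entry (hypothetical model and prices; the key names
+    match the fields declared below)::
+
+        - id: -201
+          name: "Example Vision Model"
+          provider: "OPENROUTER"
+          model_name: "vendor/example-vision-model"
+          billing_tier: "premium"
+          quota_reserve_tokens: 4000
+          input_cost_per_token: 0.000003
+          output_cost_per_token: 0.000015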
+ """
+
id: int = Field(...)
name: str
description: str | None = None
@@ -73,3 +82,26 @@ class GlobalVisionLLMConfigRead(BaseModel):
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
+ billing_tier: str = Field(
+ default="free",
+ description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
+ )
+ quota_reserve_tokens: int | None = Field(
+ default=None,
+ description=(
+ "Optional override for the per-call reservation in *tokens* — "
+ "converted to micro-USD via the model's input/output prices at "
+ "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
+ ),
+ )
+ input_cost_per_token: float | None = Field(
+ default=None,
+ description=(
+ "Optional input price in USD/token. Used by pricing_registration to "
+ "register custom Azure / OpenRouter aliases with LiteLLM at startup."
+ ),
+ )
+ output_cost_per_token: float | None = Field(
+ default=None,
+ description="Optional output price in USD/token. Pair with input_cost_per_token.",
+ )
diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py
new file mode 100644
index 000000000..f5ca9818e
--- /dev/null
+++ b/surfsense_backend/app/services/billable_calls.py
@@ -0,0 +1,430 @@
+"""
+Per-call billable wrapper for image generation, vision LLM extraction, and
+any other short-lived premium operation that must charge against the user's
+shared premium credit pool.
+
+The ``billable_call`` async context manager encapsulates the standard
+"reserve → execute → finalize / release → record audit row" lifecycle in a
+single primitive so callers (the image-generation REST route and the
+vision-LLM wrapper used during indexing) don't have to re-implement it.
+
+KEY DESIGN POINTS (issues A and B):
+
+1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
+ argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
+ insert each run inside their own ``shielded_async_session()``. This
+ guarantees that a quota commit/rollback can never accidentally flush or
+ roll back rows the caller has staged in the request's main session
+ (e.g. a freshly-created ``ImageGeneration`` row).
+
+2. **ContextVar safety.** The accumulator is scoped via
+ :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
+ nested ``billable_call`` inside an outer chat turn cannot corrupt the
+ chat turn's accumulator.
+
+3. **Free configs are still audited.** Free calls bypass the reserve /
+ finalize dance entirely but still record a ``TokenUsage`` audit row with
+ the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
+ pipeline complete for analytics even when nothing is debited.
+
+4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
+ responsible for translating that into HTTP 402. We *do not* catch the
+ denial inside ``billable_call`` — letting it propagate also prevents
+ the image-generation route from creating an ``ImageGeneration`` row
+ for a request that never actually ran.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from typing import Any
+from uuid import UUID, uuid4
+
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.config import config
+from app.db import shielded_async_session
+from app.services.token_quota_service import (
+ TokenQuotaService,
+ estimate_call_reserve_micros,
+)
+from app.services.token_tracking_service import (
+ TurnTokenAccumulator,
+ record_token_usage,
+ scoped_turn,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class QuotaInsufficientError(Exception):
+ """Raised when ``TokenQuotaService.premium_reserve`` denies a billable
+ call because the user has exhausted their premium credit pool.
+
+ The route handler should catch this and return HTTP 402 Payment
+ Required (or the equivalent for the surface area). Outside of the HTTP
+ layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing)
+ callers may catch this and degrade gracefully — e.g. fall back to OCR
+ when vision is unavailable.
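+
+    Minimal route-level sketch (illustrative; the exact handler shape is up
+    to the route)::
+
+        try:
+            async with billable_call(...) as acc:
+                ...
+        except QuotaInsufficientError as exc:
+            raise HTTPException(
+                status_code=402,
+                detail=f"Premium credit exhausted ({exc.remaining_micros} micro-USD left)",
+            )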
+ """
+
+ def __init__(
+ self,
+ *,
+ usage_type: str,
+ used_micros: int,
+ limit_micros: int,
+ remaining_micros: int,
+ ) -> None:
+ self.usage_type = usage_type
+ self.used_micros = used_micros
+ self.limit_micros = limit_micros
+ self.remaining_micros = remaining_micros
+ super().__init__(
+ f"Premium credit exhausted for {usage_type}: "
+ f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)"
+ )
+
+
+@asynccontextmanager
+async def billable_call(
+ *,
+ user_id: UUID,
+ search_space_id: int,
+ billing_tier: str,
+ base_model: str,
+ quota_reserve_tokens: int | None = None,
+ quota_reserve_micros_override: int | None = None,
+ usage_type: str,
+ thread_id: int | None = None,
+ message_id: int | None = None,
+ call_details: dict[str, Any] | None = None,
+) -> AsyncIterator[TurnTokenAccumulator]:
+ """Wrap a single billable LLM/image call.
+
+ Args:
+ user_id: Owner of the credit pool to debit. For vision-LLM during
+ indexing this is the *search-space owner* (issue M), not the
+ triggering user.
+ search_space_id: Required — recorded on the ``TokenUsage`` audit row.
+ billing_tier: ``"premium"`` debits; anything else (``"free"``) skips
+ the reserve/finalize dance but still records an audit row with
+ the captured cost.
+ base_model: Used by :func:`estimate_call_reserve_micros` to compute
+ a worst-case reservation from LiteLLM's pricing table.
+ quota_reserve_tokens: Optional per-config override for the chat-style
+ reserve estimator (vision LLM uses this).
+ quota_reserve_micros_override: Optional flat micro-USD reservation
+ (image generation uses this — its cost shape is per-image, not
+ per-token).
+ usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
+ Recorded on the ``TokenUsage`` row.
+ thread_id, message_id: Optional FK columns on ``TokenUsage``.
+ call_details: Optional per-call metadata (model name, parameters)
+ forwarded to ``record_token_usage``.
+
+ Yields:
+ The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
+ the underlying LLM/image API while inside the ``async with``; the
+ ``TokenTrackingCallback`` populates the accumulator automatically.
+
+ Raises:
+ QuotaInsufficientError: when premium and ``premium_reserve`` denies.
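+
+    Example (sketch; the model name and kwargs are illustrative, and cost
+    capture relies on the LiteLLM tracking callback being installed)::
+
+        async with billable_call(
+            user_id=owner_id,
+            search_space_id=space_id,
+            billing_tier="premium",
+            base_model="openrouter/vendor/example-image-model",
+            quota_reserve_micros_override=50_000,
+            usage_type="image_generation",
+        ) as acc:
+            await litellm.aimage_generation(...)
+        # After the block, acc.total_cost_micros holds the debited cost.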
+ """
+ is_premium = billing_tier == "premium"
+
+ async with scoped_turn() as acc:
+ # ---------- Free path: just audit -------------------------------
+ if not is_premium:
+ try:
+ yield acc
+ finally:
+                # Always audit, even on exception, so we capture cost when the
+                # provider returns successfully but the caller raises later.
+ try:
+ async with shielded_async_session() as audit_session:
+ await record_token_usage(
+ audit_session,
+ usage_type=usage_type,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ prompt_tokens=acc.total_prompt_tokens,
+ completion_tokens=acc.total_completion_tokens,
+ total_tokens=acc.grand_total,
+ cost_micros=acc.total_cost_micros,
+ model_breakdown=acc.per_message_summary(),
+ call_details=call_details,
+ thread_id=thread_id,
+ message_id=message_id,
+ )
+ await audit_session.commit()
+ except Exception:
+ logger.exception(
+ "[billable_call] free-path audit insert failed for "
+ "usage_type=%s user_id=%s",
+ usage_type,
+ user_id,
+ )
+ return
+
+ # ---------- Premium path: reserve → execute → finalize ----------
+ if quota_reserve_micros_override is not None:
+ reserve_micros = max(1, int(quota_reserve_micros_override))
+ else:
+ reserve_micros = estimate_call_reserve_micros(
+ base_model=base_model or "",
+ quota_reserve_tokens=quota_reserve_tokens,
+ )
+
+ request_id = str(uuid4())
+
+ async with shielded_async_session() as quota_session:
+ reserve_result = await TokenQuotaService.premium_reserve(
+ db_session=quota_session,
+ user_id=user_id,
+ request_id=request_id,
+ reserve_micros=reserve_micros,
+ )
+
+ if not reserve_result.allowed:
+ logger.info(
+ "[billable_call] reserve DENIED user=%s usage_type=%s "
+ "reserve=%d used=%d limit=%d remaining=%d",
+ user_id,
+ usage_type,
+ reserve_micros,
+ reserve_result.used,
+ reserve_result.limit,
+ reserve_result.remaining,
+ )
+ raise QuotaInsufficientError(
+ usage_type=usage_type,
+ used_micros=reserve_result.used,
+ limit_micros=reserve_result.limit,
+ remaining_micros=reserve_result.remaining,
+ )
+
+ logger.info(
+ "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
+ "(remaining=%d)",
+ user_id,
+ usage_type,
+ reserve_micros,
+ reserve_result.remaining,
+ )
+
+ try:
+ yield acc
+ except BaseException:
+ # Release on any failure (including QuotaInsufficientError raised
+ # from a downstream call, asyncio cancellation, etc.). We use
+ # BaseException so cancellation also releases.
+ try:
+ async with shielded_async_session() as quota_session:
+ await TokenQuotaService.premium_release(
+ db_session=quota_session,
+ user_id=user_id,
+ reserved_micros=reserve_micros,
+ )
+ except Exception:
+ logger.exception(
+ "[billable_call] premium_release failed for user=%s "
+ "reserve_micros=%d (reservation will be GC'd by quota "
+ "reconciliation if/when implemented)",
+ user_id,
+ reserve_micros,
+ )
+ raise
+
+ # ---------- Success: finalize + audit ----------------------------
+ actual_micros = acc.total_cost_micros
+ try:
+ async with shielded_async_session() as quota_session:
+ final_result = await TokenQuotaService.premium_finalize(
+ db_session=quota_session,
+ user_id=user_id,
+ request_id=request_id,
+ actual_micros=actual_micros,
+ reserved_micros=reserve_micros,
+ )
+ logger.info(
+ "[billable_call] finalize user=%s usage_type=%s actual=%d "
+ "reserved=%d → used=%d/%d (remaining=%d)",
+ user_id,
+ usage_type,
+ actual_micros,
+ reserve_micros,
+ final_result.used,
+ final_result.limit,
+ final_result.remaining,
+ )
+ except Exception:
+ # Last-ditch: if finalize itself fails, we must at least release
+ # so the reservation doesn't leak.
+ logger.exception(
+ "[billable_call] premium_finalize failed for user=%s; "
+ "attempting release",
+ user_id,
+ )
+ try:
+ async with shielded_async_session() as quota_session:
+ await TokenQuotaService.premium_release(
+ db_session=quota_session,
+ user_id=user_id,
+ reserved_micros=reserve_micros,
+ )
+ except Exception:
+ logger.exception(
+ "[billable_call] release after finalize failure ALSO failed "
+ "for user=%s",
+ user_id,
+ )
+
+ try:
+ async with shielded_async_session() as audit_session:
+ await record_token_usage(
+ audit_session,
+ usage_type=usage_type,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ prompt_tokens=acc.total_prompt_tokens,
+ completion_tokens=acc.total_completion_tokens,
+ total_tokens=acc.grand_total,
+ cost_micros=actual_micros,
+ model_breakdown=acc.per_message_summary(),
+ call_details=call_details,
+ thread_id=thread_id,
+ message_id=message_id,
+ )
+ await audit_session.commit()
+ except Exception:
+ logger.exception(
+ "[billable_call] premium-path audit insert failed for "
+ "usage_type=%s user_id=%s (debit was applied)",
+ usage_type,
+ user_id,
+ )
+
+
+async def _resolve_agent_billing_for_search_space(
+ session: AsyncSession,
+ search_space_id: int,
+ *,
+ thread_id: int | None = None,
+) -> tuple[UUID, str, str]:
+ """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
+ agent LLM.
+
+ Used by Celery tasks (podcast generation, video presentation) to bill the
+ search-space owner's premium credit pool when the agent LLM is premium.
+
+ Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
+
+ - Search space not found / no ``agent_llm_id``: raise ``ValueError``.
+ - **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
+ * ``thread_id`` is set: delegate to
+ ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
+ recurse into the resolved id. Reuses chat's existing pin if present
+ so the same model bills for chat + downstream podcast/video. If the
+ user is not premium-eligible, the pin service auto-restricts to free
+ deployments — denial only happens later in
+ ``billable_call.premium_reserve`` if the pin really is premium and
+ credit ran out mid-flow.
+      * ``thread_id`` is None: fall back to billing tier ``"free"`` and base
+        model ``"auto"``. Forward-compat for any future direct-API path;
+        today both Celery tasks always pass ``thread_id``.
+ - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
+ (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
+ ``base_model = litellm_params.get("base_model") or model_name`` —
+ NOT provider-prefixed, matching chat's cost-map lookup convention.
+ - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
+ ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
+ ``base_model`` from ``litellm_params`` or ``model_name``.
+
+ Note on imports: ``llm_service``, ``auto_model_pin_service``, and
+ ``llm_router_service`` are imported lazily inside the function body to
+ avoid hoisting litellm side-effects (``litellm.callbacks =
+ [token_tracker]``, ``litellm.drop_params``, etc.) into
+ ``billable_calls.py``'s module load path.
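+
+    Sketch of the intended Celery-side usage (the ``usage_type`` label is
+    illustrative)::
+
+        owner_id, tier, base_model = await _resolve_agent_billing_for_search_space(
+            session, search_space_id, thread_id=thread_id
+        )
+        async with billable_call(
+            user_id=owner_id,
+            search_space_id=search_space_id,
+            billing_tier=tier,
+            base_model=base_model,
+            usage_type="podcast_generation",
+        ):
+            ...  # run the agent LLM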
+ """
+ from sqlalchemy import select
+
+ from app.db import NewLLMConfig, SearchSpace
+
+ result = await session.execute(
+ select(SearchSpace).where(SearchSpace.id == search_space_id)
+ )
+ search_space = result.scalars().first()
+ if search_space is None:
+ raise ValueError(f"Search space {search_space_id} not found")
+
+ agent_llm_id = search_space.agent_llm_id
+ if agent_llm_id is None:
+ raise ValueError(
+ f"Search space {search_space_id} has no agent_llm_id configured"
+ )
+
+ owner_user_id: UUID = search_space.user_id
+
+ from app.services.auto_model_pin_service import (
+ AUTO_FASTEST_ID,
+ resolve_or_get_pinned_llm_config_id,
+ )
+
+ if agent_llm_id == AUTO_FASTEST_ID:
+ if thread_id is None:
+ return owner_user_id, "free", "auto"
+ try:
+ resolution = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=str(owner_user_id),
+ selected_llm_config_id=AUTO_FASTEST_ID,
+ )
+ except ValueError:
+ logger.warning(
+ "[agent_billing] Auto-mode pin resolution failed for "
+ "search_space=%s thread=%s; falling back to free",
+ search_space_id,
+ thread_id,
+ exc_info=True,
+ )
+ return owner_user_id, "free", "auto"
+ agent_llm_id = resolution.resolved_llm_config_id
+
+ if agent_llm_id < 0:
+ from app.services.llm_service import get_global_llm_config
+
+ cfg = get_global_llm_config(agent_llm_id) or {}
+ billing_tier = str(cfg.get("billing_tier", "free")).lower()
+ litellm_params = cfg.get("litellm_params") or {}
+ base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
+ return owner_user_id, billing_tier, base_model
+
+ nlc_result = await session.execute(
+ select(NewLLMConfig).where(
+ NewLLMConfig.id == agent_llm_id,
+ NewLLMConfig.search_space_id == search_space_id,
+ )
+ )
+ nlc = nlc_result.scalars().first()
+ base_model = ""
+ if nlc is not None:
+ litellm_params = nlc.litellm_params or {}
+ base_model = litellm_params.get("base_model") or nlc.model_name or ""
+ return owner_user_id, "free", base_model
+
+
+__all__ = [
+    "DEFAULT_IMAGE_RESERVE_MICROS",
+    "QuotaInsufficientError",
+    "_resolve_agent_billing_for_search_space",
+    "billable_call",
+]
+
+
+# Re-export the config knob so callers don't have to import config just for
+# the default image reserve.
+DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS
diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py
index 8a7b2919a..1e9d235c8 100644
--- a/surfsense_backend/app/services/llm_router_service.py
+++ b/surfsense_backend/app/services/llm_router_service.py
@@ -134,42 +134,16 @@ PROVIDER_MAP = {
}
-# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
-# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
-# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
-# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
-# request to an Azure endpoint, which then 404s with ``Resource not found``.
-# Only providers with a well-known, stable public base URL are listed here —
-# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
-# huggingface, databricks, cloudflare, replicate) are intentionally omitted
-# so their existing config-driven behaviour is preserved.
-PROVIDER_DEFAULT_API_BASE = {
- "openrouter": "https://openrouter.ai/api/v1",
- "groq": "https://api.groq.com/openai/v1",
- "mistral": "https://api.mistral.ai/v1",
- "perplexity": "https://api.perplexity.ai",
- "xai": "https://api.x.ai/v1",
- "cerebras": "https://api.cerebras.ai/v1",
- "deepinfra": "https://api.deepinfra.com/v1/openai",
- "fireworks_ai": "https://api.fireworks.ai/inference/v1",
- "together_ai": "https://api.together.xyz/v1",
- "anyscale": "https://api.endpoints.anyscale.com/v1",
- "cometapi": "https://api.cometapi.com/v1",
- "sambanova": "https://api.sambanova.ai/v1",
-}
-
-
-# Canonical provider → base URL when a config uses a generic ``openai``-style
-# prefix but the ``provider`` field tells us which API it really is
-# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
-# each has its own base URL).
-PROVIDER_KEY_DEFAULT_API_BASE = {
- "DEEPSEEK": "https://api.deepseek.com/v1",
- "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
- "MOONSHOT": "https://api.moonshot.ai/v1",
- "ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
- "MINIMAX": "https://api.minimax.io/v1",
-}
+# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
+# hoisted to ``app.services.provider_api_base`` so vision and image-gen
+# call sites can share the exact same defense (OpenRouter / Groq / etc.
+# 404-ing against an inherited Azure endpoint). Re-exported here for
+# backward compatibility with any external import.
+from app.services.provider_api_base import ( # noqa: E402
+ PROVIDER_DEFAULT_API_BASE,
+ PROVIDER_KEY_DEFAULT_API_BASE,
+ resolve_api_base,
+)
class LLMRouterService:
@@ -466,14 +440,14 @@ class LLMRouterService:
# Resolve ``api_base``. Config value wins; otherwise apply a
# provider-aware default so the deployment does not silently
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
- # requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
+ # requests to the wrong endpoint. See ``provider_api_base``
# docstring for the motivating bug (OpenRouter models 404-ing
# against an Azure endpoint).
- api_base = config.get("api_base")
- if not api_base:
- api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
- if not api_base:
- api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
+ api_base = resolve_api_base(
+ provider=provider,
+ provider_prefix=provider_prefix,
+ config_api_base=config.get("api_base"),
+ )
if api_base:
litellm_params["api_base"] = api_base
diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py
index 942a9b7af..72c10035d 100644
--- a/surfsense_backend/app/services/llm_service.py
+++ b/surfsense_backend/app/services/llm_service.py
@@ -496,8 +496,14 @@ async def get_vision_llm(
- Auto mode (ID 0): VisionLLMRouterService
- Global (negative ID): YAML configs
- DB (positive ID): VisionLLMConfig table
+
+ Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
+ so each ``ainvoke`` debits the search-space owner's premium credit
+ pool. User-owned BYOK configs and free global configs are returned
+ unwrapped — they don't consume premium credit (issue M).
"""
from app.db import VisionLLMConfig
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import (
VISION_PROVIDER_MAP,
VisionLLMRouterService,
@@ -519,6 +525,8 @@ async def get_vision_llm(
logger.error(f"No vision LLM configured for search space {search_space_id}")
return None
+ owner_user_id = search_space.user_id
+
if is_vision_auto_mode(config_id):
if not VisionLLMRouterService.is_initialized():
logger.error(
@@ -526,6 +534,13 @@ async def get_vision_llm(
)
return None
try:
+ # Auto mode is currently treated as free at the wrapper
+ # level — the underlying router can dispatch to either
+ # premium or free YAML configs but routing decisions are
+ # opaque. If/when we want to bill Auto-routed vision
+ # calls we'd need to thread the resolved deployment's
+ # billing_tier back from the router. For now we keep
+ # parity with chat Auto, which also doesn't pre-classify.
return ChatLiteLLMRouter(
router=VisionLLMRouterService.get_router(),
streaming=True,
@@ -562,8 +577,21 @@ async def get_vision_llm(
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
- return SanitizedChatLiteLLM(**litellm_kwargs)
+ inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
+ billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
+ if billing_tier == "premium":
+ return QuotaCheckedVisionLLM(
+ inner_llm,
+ user_id=owner_user_id,
+ search_space_id=search_space_id,
+ billing_tier=billing_tier,
+ base_model=model_string,
+ quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
+ )
+ return inner_llm
+
+ # User-owned (positive ID) BYOK configs — always free.
result = await session.execute(
select(VisionLLMConfig).where(
VisionLLMConfig.id == config_id,
diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py
index 7e856d015..0d030f04f 100644
--- a/surfsense_backend/app/services/openrouter_integration_service.py
+++ b/surfsense_backend/app/services/openrouter_integration_service.py
@@ -93,6 +93,35 @@ def _is_text_output_model(model: dict) -> bool:
return output_mods == ["text"]
+def _is_image_output_model(model: dict) -> bool:
+ """Return True if the model can produce image output.
+
+ OpenRouter's ``architecture.output_modalities`` is a list (e.g.
+ ``["image"]`` for pure image generators, ``["text", "image"]`` for
+ multi-modal generators that also emit captions). We accept any model
+ that can output images; the call site decides whether to use the
+ image-generation API or chat completion.
+ """
+ output_mods = model.get("architecture", {}).get("output_modalities", []) or []
+ return "image" in output_mods
+
+
+def _is_vision_input_model(model: dict) -> bool:
+ """Return True if the model can ingest an image AND emit text.
+
+ OpenRouter's ``architecture.input_modalities`` lists what the model
+ accepts; ``output_modalities`` lists what it produces. A vision LLM
+ is a model that takes images in and produces text out — i.e. it can
+ answer questions about a screenshot or extract content from an
+ image. Pure image-to-image models (e.g. style transfer) and
+ text-only models are excluded.
+ """
+ arch = model.get("architecture", {}) or {}
+ input_mods = arch.get("input_modalities", []) or []
+ output_mods = arch.get("output_modalities", []) or []
+ return "image" in input_mods and "text" in output_mods
+
+
def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or []
@@ -175,6 +204,32 @@ async def _fetch_models_async() -> list[dict] | None:
return None
+def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
+ """Return a ``{model_id: {"prompt": str, "completion": str}}`` map.
+
+ Pricing values are kept as the raw OpenRouter strings (e.g.
+ ``"0.000003"``); ``pricing_registration`` converts them to floats
+ when registering with LiteLLM. Models with missing or malformed
+    pricing are simply omitted, which is a risk for the operator only if
+    any of those models are marked premium.
+ """
+ pricing: dict[str, dict[str, str]] = {}
+ for model in raw_models:
+ model_id = str(model.get("id") or "").strip()
+ if not model_id:
+ continue
+ p = model.get("pricing") or {}
+ prompt = p.get("prompt")
+ completion = p.get("completion")
+ if prompt is None and completion is None:
+ continue
+ pricing[model_id] = {
+ "prompt": str(prompt) if prompt is not None else "",
+ "completion": str(completion) if completion is not None else "",
+ }
+ return pricing
+
+
def _generate_configs(
raw_models: list[dict],
settings: dict[str, Any],
@@ -282,6 +337,162 @@ def _generate_configs(
return configs
+# ID-offset bands used to keep dynamic OpenRouter configs in their own
+# namespace per surface. Image / vision get separate bands so a single
+# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
+_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
+_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
+
+
+def _generate_image_gen_configs(
+ raw_models: list[dict], settings: dict[str, Any]
+) -> list[dict]:
+ """Convert OpenRouter image-generation models into global image-gen
+ config dicts (matches the YAML shape consumed by ``image_generation_routes``).
+
+ Filter:
+ - architecture.output_modalities contains "image"
+ - compatible provider (excluded slugs blocked)
+ - allowed model id (excluded list blocked)
+
+ Notably we *drop* the chat-only filters (``_supports_tool_calling`` and
+ ``_has_sufficient_context``) because tool calls and context windows are
+ irrelevant for the ``aimage_generation`` API. ``billing_tier`` is
+ derived per model the same way as chat (``_openrouter_tier``).
+
+ Cost is intentionally *not* registered with LiteLLM at startup
+ (``pricing_registration`` skips image gen): OpenRouter image-gen
+ models are not in LiteLLM's native cost map and OpenRouter populates
+ ``response_cost`` directly from the response header. A defensive
+ branch in ``_extract_cost_usd`` handles the rare case where
+ ``usage.cost`` is missing — see ``token_tracking_service``.
+ """
+ id_offset: int = int(
+ settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
+ )
+ api_key: str = settings.get("api_key", "")
+ rpm: int = settings.get("rpm", 200)
+ free_rpm: int = settings.get("free_rpm", 20)
+ litellm_params: dict = settings.get("litellm_params") or {}
+
+ image_models = [
+ m
+ for m in raw_models
+ if _is_image_output_model(m)
+ and _is_compatible_provider(m)
+ and _is_allowed_model(m)
+ and "/" in m.get("id", "")
+ ]
+
+ configs: list[dict] = []
+ taken: set[int] = set()
+ for model in image_models:
+ model_id: str = model["id"]
+ name: str = model.get("name", model_id)
+ tier = _openrouter_tier(model)
+
+ cfg: dict[str, Any] = {
+ "id": _stable_config_id(model_id, id_offset, taken),
+ "name": name,
+ "description": f"{name} via OpenRouter (image generation)",
+ "provider": "OPENROUTER",
+ "model_name": model_id,
+ "api_key": api_key,
+ "api_base": "",
+ "api_version": None,
+ "rpm": free_rpm if tier == "free" else rpm,
+ "litellm_params": dict(litellm_params),
+ "billing_tier": tier,
+ _OPENROUTER_DYNAMIC_MARKER: True,
+ }
+ configs.append(cfg)
+
+ return configs
+
+
+def _generate_vision_llm_configs(
+ raw_models: list[dict], settings: dict[str, Any]
+) -> list[dict]:
+ """Convert OpenRouter vision-capable LLMs into global vision-LLM config
+ dicts (matches the YAML shape consumed by ``vision_llm_routes``).
+
+ Filter:
+ - architecture.input_modalities contains "image"
+ - architecture.output_modalities contains "text"
+ - compatible provider (excluded slugs blocked)
+ - allowed model id (excluded list blocked)
+
+ Vision-LLM is invoked from the indexer (image extraction during
+ document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
+ the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
+ filters do not apply: a small-context vision model that doesn't
+ advertise tool-calling is still perfectly viable for "describe this
+ image" prompts.
+ """
+ id_offset: int = int(
+ settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
+ )
+ api_key: str = settings.get("api_key", "")
+ rpm: int = settings.get("rpm", 200)
+ tpm: int = settings.get("tpm", 1_000_000)
+ free_rpm: int = settings.get("free_rpm", 20)
+ free_tpm: int = settings.get("free_tpm", 100_000)
+ quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
+ litellm_params: dict = settings.get("litellm_params") or {}
+
+ vision_models = [
+ m
+ for m in raw_models
+ if _is_vision_input_model(m)
+ and _is_compatible_provider(m)
+ and _is_allowed_model(m)
+ and "/" in m.get("id", "")
+ ]
+
+ configs: list[dict] = []
+ taken: set[int] = set()
+ for model in vision_models:
+ model_id: str = model["id"]
+ name: str = model.get("name", model_id)
+ tier = _openrouter_tier(model)
+ pricing = model.get("pricing") or {}
+
+ # Capture per-token prices so ``pricing_registration`` can
+ # register them with LiteLLM at startup (and so the cost
+ # estimator in ``estimate_call_reserve_micros`` can resolve
+ # them at reserve time).
+ try:
+ input_cost = float(pricing.get("prompt", 0) or 0)
+ except (TypeError, ValueError):
+ input_cost = 0.0
+ try:
+ output_cost = float(pricing.get("completion", 0) or 0)
+ except (TypeError, ValueError):
+ output_cost = 0.0
+
+ cfg: dict[str, Any] = {
+ "id": _stable_config_id(model_id, id_offset, taken),
+ "name": name,
+ "description": f"{name} via OpenRouter (vision)",
+ "provider": "OPENROUTER",
+ "model_name": model_id,
+ "api_key": api_key,
+ "api_base": "",
+ "api_version": None,
+ "rpm": free_rpm if tier == "free" else rpm,
+ "tpm": free_tpm if tier == "free" else tpm,
+ "litellm_params": dict(litellm_params),
+ "billing_tier": tier,
+ "quota_reserve_tokens": quota_reserve_tokens,
+ "input_cost_per_token": input_cost or None,
+ "output_cost_per_token": output_cost or None,
+ _OPENROUTER_DYNAMIC_MARKER: True,
+ }
+ configs.append(cfg)
+
+ return configs
+
+
class OpenRouterIntegrationService:
"""Singleton that manages the dynamic OpenRouter model catalogue."""
@@ -300,6 +511,19 @@ class OpenRouterIntegrationService:
# Shape: {model_name: {"gated": bool, "score": float | None}}
self._health_cache: dict[str, dict[str, Any]] = {}
self._enrich_task: asyncio.Task | None = None
+ # Raw OpenRouter pricing per model_id, captured at the same time
+ # we generate configs. Consumed by ``pricing_registration`` to
+ # teach LiteLLM the per-token cost of every dynamic deployment so
+ # the success-callback can populate ``response_cost`` correctly.
+ self._raw_pricing: dict[str, dict[str, str]] = {}
+ # Cached raw catalogue from the most recent fetch. Image / vision
+ # emitters reuse this to avoid a second network call per surface.
+ self._raw_models: list[dict] = []
+ # Image / vision config caches (only populated when the matching
+ # opt-in flag is true on initialize). Refreshed in lockstep with
+ # the chat catalogue.
+ self._image_configs: list[dict] = []
+ self._vision_configs: list[dict] = []
@classmethod
def get_instance(cls) -> "OpenRouterIntegrationService":
@@ -329,8 +553,32 @@ class OpenRouterIntegrationService:
self._initialized = True
return []
+ self._raw_models = raw_models
self._configs = _generate_configs(raw_models, settings)
self._configs_by_id = {c["id"]: c for c in self._configs}
+ self._raw_pricing = _extract_raw_pricing(raw_models)
+
+ # Populate image / vision caches when their opt-in flag is set.
+ # Empty otherwise so the accessors return [] without re-running
+ # filters every refresh.
+ if settings.get("image_generation_enabled"):
+ self._image_configs = _generate_image_gen_configs(raw_models, settings)
+ logger.info(
+ "OpenRouter integration: image-gen emission ON (%d models)",
+ len(self._image_configs),
+ )
+ else:
+ self._image_configs = []
+
+ if settings.get("vision_enabled"):
+ self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
+ logger.info(
+ "OpenRouter integration: vision LLM emission ON (%d models)",
+ len(self._vision_configs),
+ )
+ else:
+ self._vision_configs = []
+
self._initialized = True
tier_counts = self._tier_counts(self._configs)
@@ -369,6 +617,8 @@ class OpenRouterIntegrationService:
new_configs = _generate_configs(raw_models, self._settings)
new_by_id = {c["id"]: c for c in new_configs}
+ self._raw_pricing = _extract_raw_pricing(raw_models)
+ self._raw_models = raw_models
from app.config import config as app_config
@@ -382,6 +632,29 @@ class OpenRouterIntegrationService:
self._configs = new_configs
self._configs_by_id = new_by_id
+ # Image / vision lists are atomic-swapped the same way: filter out
+ # the previous dynamic entries from the live config list and append
+ # the freshly generated ones. No-ops when the opt-in flag is off.
+ if self._settings.get("image_generation_enabled"):
+ new_image = _generate_image_gen_configs(raw_models, self._settings)
+ static_image = [
+ c
+ for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
+ if not c.get(_OPENROUTER_DYNAMIC_MARKER)
+ ]
+ app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
+ self._image_configs = new_image
+
+ if self._settings.get("vision_enabled"):
+ new_vision = _generate_vision_llm_configs(raw_models, self._settings)
+ static_vision = [
+ c
+ for c in app_config.GLOBAL_VISION_LLM_CONFIGS
+ if not c.get(_OPENROUTER_DYNAMIC_MARKER)
+ ]
+ app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
+ self._vision_configs = new_vision
+
# Catalogue churn invalidates per-config "recently healthy" credit
# earned by the previous turn's preflight. Drop the whole table so
# the next turn re-probes against the freshly loaded configs.
@@ -407,6 +680,21 @@ class OpenRouterIntegrationService:
# so a hand-picked dead OR model is gated like a dynamic one.
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
+ # Re-register LiteLLM pricing for the freshly fetched catalogue
+ # so newly added OR models bill correctly on their first call.
+ # Runs before the router rebuild because the router may issue
+ # cost-table lookups during deployment registration.
+ try:
+ from app.services.pricing_registration import (
+ register_pricing_from_global_configs,
+ )
+
+ register_pricing_from_global_configs()
+ except Exception as exc:
+ logger.warning(
+ "OpenRouter refresh: pricing re-registration skipped (%s)", exc
+ )
+
# Rebuild the LiteLLM router so freshly fetched configs flow through
# (dynamic OR premium entries now opt into the pool, free ones stay
# out; a refresh also needs to pick up any static-config edits and
@@ -635,3 +923,34 @@ class OpenRouterIntegrationService:
def get_config_by_id(self, config_id: int) -> dict | None:
return self._configs_by_id.get(config_id)
+
+ def get_image_generation_configs(self) -> list[dict]:
+ """Return the dynamic OpenRouter image-generation configs (empty
+ list when the ``image_generation_enabled`` flag is off).
+
+ Each entry already has ``billing_tier`` derived per-model from
+ OpenRouter's signals and is shaped to drop directly into
+ ``Config.GLOBAL_IMAGE_GEN_CONFIGS``.
+ """
+ return list(self._image_configs)
+
+ def get_vision_llm_configs(self) -> list[dict]:
+ """Return the dynamic OpenRouter vision-LLM configs (empty list
+ when the ``vision_enabled`` flag is off).
+
+ Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
+ so ``pricing_registration`` can teach LiteLLM the cost of these
+ models the same way it does for chat — which keeps the billable
+ wrapper able to debit accurate micro-USD on a vision call.
+ """
+ return list(self._vision_configs)
+
+ def get_raw_pricing(self) -> dict[str, dict[str, str]]:
+ """Return the cached raw OpenRouter pricing map.
+
+ Shape: ``{model_id: {"prompt": str, "completion": str}}``. The
+ values are the strings OpenRouter publishes (USD per token),
+ never converted to floats here so the caller can decide how to
+ handle malformed or unset entries.
+ """
+ return dict(self._raw_pricing)
diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py
new file mode 100644
index 000000000..de98e50c2
--- /dev/null
+++ b/surfsense_backend/app/services/pricing_registration.py
@@ -0,0 +1,274 @@
+"""
+Pricing registration with LiteLLM.
+
+Many models reach our LiteLLM callback without LiteLLM knowing their
+per-token cost — namely:
+
+* The ~300 dynamic OpenRouter deployments (their pricing only lives on
+ OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
+ pricing table).
+* Static YAML deployments whose ``base_model`` name is operator-defined
+ (e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
+ not in LiteLLM's table either.
+
+Without registration, ``kwargs["response_cost"]`` is 0 for those calls
+and the user gets billed nothing — a fail-safe but wrong answer for a
+cost-based credit system. This module runs once at startup, after the
+OpenRouter integration has fetched its catalogue, and registers each
+known model's pricing with ``litellm.register_model()`` under multiple
+plausible alias keys (LiteLLM's cost lookup may use any of them
+depending on whether the call went through the Router, ChatLiteLLM,
+or a direct ``acompletion``).
+
+Operators who run a custom Azure deployment whose ``base_model`` name
+isn't in LiteLLM's table can declare per-token pricing inline in
+``global_llm_config.yaml`` via ``input_cost_per_token`` and
+``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
+that declaration the model's calls debit 0 — never overbilled.
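+
+Illustrative ``global_llm_config.yaml`` fragment (deployment name and prices
+are hypothetical; the two pricing keys are the ones this module reads)::
+
+    - name: "Custom Azure GPT"
+      provider: "AZURE_OPENAI"
+      model_name: "my-custom-deployment"
+      litellm_params:
+        base_model: "gpt-5.4"
+      input_cost_per_token: 0.000002
+      output_cost_per_token: 0.000008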
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+import litellm
+
+logger = logging.getLogger(__name__)
+
+
+def _safe_float(value: Any) -> float:
+ """Return ``float(value)`` if it parses to a positive number, else 0.0."""
+ if value is None:
+ return 0.0
+ try:
+ f = float(value)
+ except (TypeError, ValueError):
+ return 0.0
+ return f if f > 0 else 0.0
+
+
+def _alias_set_for_openrouter(model_id: str) -> list[str]:
+ """Return the alias keys to register an OpenRouter model under.
+
+ LiteLLM's cost-callback lookup key varies by call path:
+ - Router with ``model="openrouter/X"`` → kwargs["model"] is
+ typically ``openrouter/X``.
+ - LiteLLM's own provider routing may strip the prefix and pass the
+ bare ``X`` to the cost-table lookup.
+ Registering under both keeps the lookup hermetic regardless of
+ which path the call took.
+ """
+ aliases = [f"openrouter/{model_id}", model_id]
+ return list(dict.fromkeys(a for a in aliases if a))
+
+
+def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
+ """Return the alias keys to register a static YAML deployment under.
+
+    Same reasoning as the OpenRouter set: cover the bare ``base_model``,
+    the ``provider/``-prefixed forms LiteLLM Router constructs, and the
+ bare ``model_name`` because callbacks sometimes see whichever was
+ configured first.
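+
+    For example (hypothetical names), ``provider="AZURE_OPENAI"``,
+    ``model_name="my-gpt4o"``, ``base_model="gpt-4o"`` yields::
+
+        ["gpt-4o", "azure_openai/gpt-4o", "my-gpt4o",
+         "azure_openai/my-gpt4o", "azure/gpt-4o", "azure/my-gpt4o"]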
+ """
+ provider_lower = (provider or "").lower()
+ aliases: list[str] = []
+ if base_model:
+ aliases.append(base_model)
+ if provider_lower and base_model:
+ aliases.append(f"{provider_lower}/{base_model}")
+ if model_name and model_name != base_model:
+ aliases.append(model_name)
+ if provider_lower and model_name and model_name != base_model:
+ aliases.append(f"{provider_lower}/{model_name}")
+    # Azure deployments often surface as "azure/<deployment>"; normalise the
+ # ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
+ if provider_lower == "azure_openai":
+ if base_model:
+ aliases.append(f"azure/{base_model}")
+ if model_name and model_name != base_model:
+ aliases.append(f"azure/{model_name}")
+ return list(dict.fromkeys(a for a in aliases if a))
+
+
+def _register(
+ aliases: list[str],
+ *,
+ input_cost: float,
+ output_cost: float,
+ provider: str,
+ mode: str = "chat",
+) -> int:
+ """Register a single pricing entry under every alias in ``aliases``.
+
+ Returns the count of aliases successfully registered.
+ """
+ payload: dict[str, dict[str, Any]] = {}
+ for alias in aliases:
+ payload[alias] = {
+ "input_cost_per_token": input_cost,
+ "output_cost_per_token": output_cost,
+ "litellm_provider": provider,
+ "mode": mode,
+ }
+ if not payload:
+ return 0
+ try:
+ litellm.register_model(payload)
+ except Exception as exc:
+ logger.warning(
+ "[PricingRegistration] register_model failed for aliases=%s: %s",
+ aliases,
+ exc,
+ )
+ return 0
+ return len(payload)
+
+
+def _register_chat_shape_configs(
+ configs: list[dict],
+ *,
+ or_pricing: dict[str, dict[str, str]],
+ label: str,
+) -> tuple[int, int, int, list[str]]:
+ """Common loop that registers per-token pricing for a list of "chat-shape"
+ configs (chat or vision LLM — both use ``input_cost_per_token`` /
+ ``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).
+
+ Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
+ """
+ registered_models = 0
+ registered_aliases = 0
+ skipped_no_pricing = 0
+ sample_keys: list[str] = []
+
+ for cfg in configs:
+ provider = str(cfg.get("provider") or "").upper()
+ model_name = str(cfg.get("model_name") or "").strip()
+ litellm_params = cfg.get("litellm_params") or {}
+ base_model = str(litellm_params.get("base_model") or model_name).strip()
+
+ if provider == "OPENROUTER":
+ entry = or_pricing.get(model_name)
+ if entry:
+ input_cost = _safe_float(entry.get("prompt"))
+ output_cost = _safe_float(entry.get("completion"))
+ else:
+ # Vision configs from ``_generate_vision_llm_configs``
+ # carry their pricing inline because the OpenRouter
+ # raw-pricing cache is keyed by chat-catalogue model_id;
+ # vision flows pick up the inline values here.
+ input_cost = _safe_float(cfg.get("input_cost_per_token"))
+ output_cost = _safe_float(cfg.get("output_cost_per_token"))
+ if input_cost == 0.0 and output_cost == 0.0:
+ skipped_no_pricing += 1
+ continue
+ aliases = _alias_set_for_openrouter(model_name)
+ count = _register(
+ aliases,
+ input_cost=input_cost,
+ output_cost=output_cost,
+ provider="openrouter",
+ )
+ if count > 0:
+ registered_models += 1
+ registered_aliases += count
+ if len(sample_keys) < 6:
+ sample_keys.extend(aliases[:2])
+ continue
+
+ input_cost = _safe_float(
+ cfg.get("input_cost_per_token")
+ or litellm_params.get("input_cost_per_token")
+ )
+ output_cost = _safe_float(
+ cfg.get("output_cost_per_token")
+ or litellm_params.get("output_cost_per_token")
+ )
+ if input_cost == 0.0 and output_cost == 0.0:
+ skipped_no_pricing += 1
+ continue
+ aliases = _alias_set_for_yaml(provider, model_name, base_model)
+ provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
+ count = _register(
+ aliases,
+ input_cost=input_cost,
+ output_cost=output_cost,
+ provider=provider_slug,
+ )
+ if count > 0:
+ registered_models += 1
+ registered_aliases += count
+ if len(sample_keys) < 6:
+ sample_keys.extend(aliases[:2])
+
+ logger.info(
+ "[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
+ "%d configs had no pricing data; sample registered keys=%s",
+ label,
+ registered_models,
+ registered_aliases,
+ skipped_no_pricing,
+ sample_keys,
+ )
+ return registered_models, registered_aliases, skipped_no_pricing, sample_keys
+
+
+def register_pricing_from_global_configs() -> None:
+ """Register pricing for every known LLM deployment with LiteLLM.
+
+ Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
+ so vision calls (during indexing) can resolve cost the same way chat
+ calls do — namely:
+
+ 1. ``OPENROUTER``: pulls the cached raw pricing from
+ ``OpenRouterIntegrationService`` (populated during its own
+ startup fetch) and converts the per-token strings to floats. For
+ vision configs that carry pricing inline (``input_cost_per_token`` /
+ ``output_cost_per_token`` set on the cfg itself) we fall back to
+ those values when the OR cache misses the model.
+ 2. Anything else: looks for operator-declared
+ ``input_cost_per_token`` / ``output_cost_per_token`` on the YAML
+ config block (top-level or nested under ``litellm_params``).
+
+ **Image generation is intentionally NOT registered here.** The cost
+ shape for image-gen is per-image (``output_cost_per_image``), not
+ per-token, and LiteLLM's ``register_model`` doesn't accept those
+ keys via the chat-cost path. OpenRouter image-gen models populate
+ ``response_cost`` directly from their response header instead, and
+ Azure-native image-gen models are already in LiteLLM's cost map.
+
+ Calls without a resolved pair of costs are skipped, not registered
+ with zeros — operators who forget pricing get a "$0 debit" warning
+ in ``TokenTrackingCallback`` rather than silently overwriting any
+ pricing LiteLLM might know natively.
+ """
+ from app.config import config as app_config
+
+ chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
+ vision_configs: list[dict] = list(
+ getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
+ )
+ if not chat_configs and not vision_configs:
+ logger.info("[PricingRegistration] no global configs to register")
+ return
+
+ or_pricing: dict[str, dict[str, str]] = {}
+ try:
+ from app.services.openrouter_integration_service import (
+ OpenRouterIntegrationService,
+ )
+
+ if OpenRouterIntegrationService.is_initialized():
+ or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
+ except Exception as exc:
+ logger.debug(
+ "[PricingRegistration] OpenRouter pricing not available yet: %s", exc
+ )
+
+ if chat_configs:
+ _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
+ if vision_configs:
+ _register_chat_shape_configs(
+ vision_configs, or_pricing=or_pricing, label="vision"
+ )
diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py
new file mode 100644
index 000000000..979d7d3a1
--- /dev/null
+++ b/surfsense_backend/app/services/provider_api_base.py
@@ -0,0 +1,107 @@
+"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
+
+LiteLLM falls back to the module-global ``litellm.api_base`` when an
+individual call doesn't pass one, which silently inherits provider-agnostic
+env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
+explicit ``api_base``, an ``openrouter/`` request can end up at an
+Azure endpoint and 404 with ``Resource not found`` (real reproducer:
+[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
+``/chat/completions`` to whatever inherited base it gets, regardless of
+provider).
+
+The chat router has had this defense for a while
+(``llm_router_service.py:466-478``). This module hoists the maps + cascade
+into a tiny standalone helper so vision and image-gen can share the same
+source of truth without an inter-service circular import.
+"""
+
+from __future__ import annotations
+
+
+PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
+ "openrouter": "https://openrouter.ai/api/v1",
+ "groq": "https://api.groq.com/openai/v1",
+ "mistral": "https://api.mistral.ai/v1",
+ "perplexity": "https://api.perplexity.ai",
+ "xai": "https://api.x.ai/v1",
+ "cerebras": "https://api.cerebras.ai/v1",
+ "deepinfra": "https://api.deepinfra.com/v1/openai",
+ "fireworks_ai": "https://api.fireworks.ai/inference/v1",
+ "together_ai": "https://api.together.xyz/v1",
+ "anyscale": "https://api.endpoints.anyscale.com/v1",
+ "cometapi": "https://api.cometapi.com/v1",
+ "sambanova": "https://api.sambanova.ai/v1",
+}
+"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
+
+Only providers with a well-known, stable public base URL are listed —
+self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
+huggingface, databricks, cloudflare, replicate) are intentionally omitted
+so their existing config-driven behaviour is preserved."""
+
+
+PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
+ "DEEPSEEK": "https://api.deepseek.com/v1",
+ "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
+ "MOONSHOT": "https://api.moonshot.ai/v1",
+ "ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
+ "MINIMAX": "https://api.minimax.io/v1",
+}
+"""Canonical provider key (uppercase) → base URL.
+
+Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
+config's ``provider`` field tells us which API it actually is (DeepSeek,
+Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
+has its own base URL)."""
+
+
+def resolve_api_base(
+ *,
+ provider: str | None,
+ provider_prefix: str | None,
+ config_api_base: str | None,
+) -> str | None:
+ """Resolve a non-Azure-leaking ``api_base`` for a deployment.
+
+ Cascade (first non-empty wins):
+ 1. The config's own ``api_base`` (whitespace-only treated as missing).
+ 2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
+ 3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
+ 4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM
+ provider integration apply its own default (e.g. AzureOpenAI's
+ deployment-derived URL, custom provider's per-deployment URL).
+
+ Args:
+ provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
+ ``"DEEPSEEK"``). Case-insensitive.
+ provider_prefix: The LiteLLM model-string prefix the same call
+ site builds for the model id (e.g. ``"openrouter"``,
+ ``"groq"``). Case-insensitive.
+ config_api_base: ``api_base`` from the global YAML / DB row /
+ OpenRouter dynamic config. Empty / whitespace-only means
+ "missing" — the resolver still applies the cascade.
+
+ Returns:
+ A URL string, or ``None`` if no default applies for this provider.
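+
+    Examples (values taken from the maps above)::
+
+        resolve_api_base(provider="OPENROUTER", provider_prefix="openrouter",
+                         config_api_base="")
+        # -> "https://openrouter.ai/api/v1"
+
+        resolve_api_base(provider="DEEPSEEK", provider_prefix="openai",
+                         config_api_base=None)
+        # -> "https://api.deepseek.com/v1"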
+ """
+ if config_api_base and config_api_base.strip():
+ return config_api_base
+
+ if provider:
+ key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
+ if key_default:
+ return key_default
+
+ if provider_prefix:
+ prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
+ if prefix_default:
+ return prefix_default
+
+ return None
+
+
+__all__ = [
+ "PROVIDER_DEFAULT_API_BASE",
+ "PROVIDER_KEY_DEFAULT_API_BASE",
+ "resolve_api_base",
+]
diff --git a/surfsense_backend/app/services/quota_checked_vision_llm.py b/surfsense_backend/app/services/quota_checked_vision_llm.py
new file mode 100644
index 000000000..0040e5a5b
--- /dev/null
+++ b/surfsense_backend/app/services/quota_checked_vision_llm.py
@@ -0,0 +1,105 @@
+"""
+Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
+
+Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
+indexing pipeline (file processors, connector indexers, etl pipeline) can
+keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)``
+— without threading ``user_id`` through every parser. The wrapper looks like
+a chat model from the outside; on the inside it routes each call through
+``billable_call`` so the user's premium credit pool is reserved → finalized
+or released, and a ``TokenUsage`` audit row is written.
+
+Free configs are returned unwrapped from ``get_vision_llm`` (they do not
+need quota enforcement) so this class only ever wraps premium configs.
+
+Why a wrapper instead of plumbing ``user_id`` through every caller:
+
+* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
+ Dropbox, local-folder, file-processor, ETL pipeline) each calling
+ ``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
+ invasive, error-prone, and easy for a future indexer to forget.
+* Per the design (issue M), we always debit the *search-space owner*, not
+ the triggering user, so ``user_id`` is fully derivable from the search
+ space the caller is already operating on. The wrapper captures it once
+ at construction time.
+* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
+ call run this coroutine"; subclassing isn't safe across versions because
+ it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
+ Composition via attribute proxying (``__getattr__``) is robust to
+ upstream changes — every method other than ``ainvoke`` falls through to
+ the inner LLM unchanged.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+from uuid import UUID
+
+from app.services.billable_calls import QuotaInsufficientError, billable_call
+
+logger = logging.getLogger(__name__)
+
+
+class QuotaCheckedVisionLLM:
+ """Composition wrapper around a langchain chat model that enforces
+ premium credit quota on every ``ainvoke``.
+
+ Anything other than ``ainvoke`` is forwarded to the inner model so
+ ``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
+ still work — they simply bypass quota enforcement, which is fine
+ because the indexing pipeline only ever calls ``ainvoke`` today.
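+
+    Caller-side sketch (argument list elided; ``get_vision_llm`` returns this
+    wrapper only for premium global configs)::
+
+        llm = await get_vision_llm(...)
+        try:
+            result = await llm.ainvoke(messages)
+        except QuotaInsufficientError:
+            result = None  # fall back to the non-vision parser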
+ """
+
+ def __init__(
+ self,
+ inner_llm: Any,
+ *,
+ user_id: UUID,
+ search_space_id: int,
+ billing_tier: str,
+ base_model: str,
+ quota_reserve_tokens: int | None,
+ usage_type: str = "vision_extraction",
+ ) -> None:
+ self._inner = inner_llm
+ self._user_id = user_id
+ self._search_space_id = search_space_id
+ self._billing_tier = billing_tier
+ self._base_model = base_model
+ self._quota_reserve_tokens = quota_reserve_tokens
+ self._usage_type = usage_type
+
+ async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
+ """Proxied async invoke that runs the underlying call inside
+ ``billable_call``.
+
+ Raises:
+ QuotaInsufficientError: when the user has exhausted their
+ premium credit pool. Caller (``etl_pipeline_service._extract_image``)
+ catches this and falls back to the document parser.
+ """
+ async with billable_call(
+ user_id=self._user_id,
+ search_space_id=self._search_space_id,
+ billing_tier=self._billing_tier,
+ base_model=self._base_model,
+ quota_reserve_tokens=self._quota_reserve_tokens,
+ usage_type=self._usage_type,
+ call_details={"model": self._base_model},
+ ):
+ return await self._inner.ainvoke(input, *args, **kwargs)
+
+ def __getattr__(self, name: str) -> Any:
+ """Forward everything else (``invoke``, ``astream``, ``bind``,
+ ``with_structured_output``, …) to the inner model.
+
+ ``__getattr__`` is only consulted when the attribute is *not*
+ already found on the proxy, which is exactly the contract we
+ want — methods we override stay on the proxy, the rest fall
+ through.
+ """
+ return getattr(self._inner, name)
+
+
+__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]
diff --git a/surfsense_backend/app/services/token_quota_service.py b/surfsense_backend/app/services/token_quota_service.py
index a3ec7aed0..310c3eb5e 100644
--- a/surfsense_backend/app/services/token_quota_service.py
+++ b/surfsense_backend/app/services/token_quota_service.py
@@ -22,6 +22,71 @@ from app.config import config
logger = logging.getLogger(__name__)
+# ---------------------------------------------------------------------------
+# Per-call reservation estimator (USD micro-units)
+# ---------------------------------------------------------------------------
+
+# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
+# request, and so models without registered pricing reserve at least
+# something while the call runs (they are debited 0 at finalize anyway when
+# their cost can't be resolved).
+_QUOTA_MIN_RESERVE_MICROS = 100
+
+
+def estimate_call_reserve_micros(
+ *,
+ base_model: str,
+ quota_reserve_tokens: int | None,
+) -> int:
+ """Return the number of micro-USD to reserve for one premium call.
+
+ Computes a worst-case upper bound from LiteLLM's per-token pricing
+ table:
+
+ reserve_usd ≈ reserve_tokens x (input_cost + output_cost)
+
+ so the math scales with model cost — Claude Opus + 4K reserve_tokens
+ naturally reserves ≈ $0.36, while a cheap model reserves only a few
+ cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]``
+ so a misconfigured "$1000/M" model can't lock the whole balance on
+ one call.
+
+ If ``litellm.get_model_info`` raises (model unknown) we fall back to
+ the floor — 100 micros / $0.0001 — which is enough to gate a sane
+ request without over-reserving for a model whose pricing the
+ operator hasn't declared yet.
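+
+    Worked example (Claude Opus list pricing of $15/M input and $75/M
+    output tokens, 4_000 reserve tokens)::
+
+        4_000 * (0.000015 + 0.000075) USD = $0.36  ->  360_000 micros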
+ """
+ reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
+ if reserve_tokens <= 0:
+ reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL
+
+ try:
+ from litellm import get_model_info
+
+ info = get_model_info(base_model) if base_model else {}
+ input_cost = float(info.get("input_cost_per_token") or 0.0)
+ output_cost = float(info.get("output_cost_per_token") or 0.0)
+ except Exception as exc:
+ logger.debug(
+ "[quota_reserve] cost lookup failed for base_model=%s: %s",
+ base_model,
+ exc,
+ )
+ input_cost = 0.0
+ output_cost = 0.0
+
+ if input_cost == 0.0 and output_cost == 0.0:
+ return _QUOTA_MIN_RESERVE_MICROS
+
+ reserve_usd = reserve_tokens * (input_cost + output_cost)
+ reserve_micros = round(reserve_usd * 1_000_000)
+ if reserve_micros < _QUOTA_MIN_RESERVE_MICROS:
+ reserve_micros = _QUOTA_MIN_RESERVE_MICROS
+ if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS:
+ reserve_micros = config.QUOTA_MAX_RESERVE_MICROS
+ return reserve_micros
+
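+# Worked example (illustrative pricing, not read from LiteLLM's live table):
+# a model priced at $15/M input and $75/M output tokens gives
+# input_cost + output_cost = 0.000090 USD per token, so with
+# reserve_tokens = 4000:
+#
+#     reserve_usd    = 4000 * 0.000090          # 0.36 USD
+#     reserve_micros = round(0.36 * 1_000_000)  # 360_000 micros
+#
+# which is then clamped to [_QUOTA_MIN_RESERVE_MICROS,
+# config.QUOTA_MAX_RESERVE_MICROS] before the reservation is placed and later
+# settled by premium_finalize with the actual LiteLLM-reported cost.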
+
class QuotaScope(StrEnum):
ANONYMOUS = "anonymous"
PREMIUM = "premium"
@@ -444,8 +509,16 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
- reserve_tokens: int,
+ reserve_micros: int,
) -> QuotaResult:
+ """Reserve ``reserve_micros`` (USD micro-units) from the user's
+ premium credit balance.
+
+ ``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
+ all in micro-USD on this code path; callers (chat stream,
+ token-status route, FE display) convert to dollars by dividing
+ by 1_000_000.
+ """
from app.db import User
user = (
@@ -465,11 +538,11 @@ class TokenQuotaService:
limit=0,
)
- limit = user.premium_tokens_limit
- used = user.premium_tokens_used
- reserved = user.premium_tokens_reserved
+ limit = user.premium_credit_micros_limit
+ used = user.premium_credit_micros_used
+ reserved = user.premium_credit_micros_reserved
- effective = used + reserved + reserve_tokens
+ effective = used + reserved + reserve_micros
if effective > limit:
remaining = max(0, limit - used - reserved)
await db_session.rollback()
@@ -482,10 +555,10 @@ class TokenQuotaService:
remaining=remaining,
)
- user.premium_tokens_reserved = reserved + reserve_tokens
+ user.premium_credit_micros_reserved = reserved + reserve_micros
await db_session.commit()
- new_reserved = reserved + reserve_tokens
+ new_reserved = reserved + reserve_micros
remaining = max(0, limit - used - new_reserved)
warning_threshold = int(limit * 0.8)
@@ -510,9 +583,12 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
- actual_tokens: int,
- reserved_tokens: int,
+ actual_micros: int,
+ reserved_micros: int,
) -> QuotaResult:
+ """Settle the reservation: release ``reserved_micros`` and debit
+ ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
+ """
from app.db import User
user = (
@@ -529,16 +605,18 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
- user.premium_tokens_reserved = max(
- 0, user.premium_tokens_reserved - reserved_tokens
+ user.premium_credit_micros_reserved = max(
+ 0, user.premium_credit_micros_reserved - reserved_micros
+ )
+ user.premium_credit_micros_used = (
+ user.premium_credit_micros_used + actual_micros
)
- user.premium_tokens_used = user.premium_tokens_used + actual_tokens
await db_session.commit()
- limit = user.premium_tokens_limit
- used = user.premium_tokens_used
- reserved = user.premium_tokens_reserved
+ limit = user.premium_credit_micros_limit
+ used = user.premium_credit_micros_used
+ reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)
@@ -562,8 +640,13 @@ class TokenQuotaService:
async def premium_release(
db_session: AsyncSession,
user_id: Any,
- reserved_tokens: int,
+ reserved_micros: int,
) -> None:
+ """Release ``reserved_micros`` previously held by ``premium_reserve``.
+
+ Used when a request fails before finalize (so the reservation
+ doesn't leak credit).
+ """
from app.db import User
user = (
@@ -576,8 +659,8 @@ class TokenQuotaService:
.scalar_one_or_none()
)
if user is not None:
- user.premium_tokens_reserved = max(
- 0, user.premium_tokens_reserved - reserved_tokens
+ user.premium_credit_micros_reserved = max(
+ 0, user.premium_credit_micros_reserved - reserved_micros
)
await db_session.commit()
@@ -598,9 +681,9 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
- limit = user.premium_tokens_limit
- used = user.premium_tokens_used
- reserved = user.premium_tokens_reserved
+ limit = user.premium_credit_micros_limit
+ used = user.premium_credit_micros_used
+ reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)
diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py
index 9aa8c6e70..9406d9be4 100644
--- a/surfsense_backend/app/services/token_tracking_service.py
+++ b/surfsense_backend/app/services/token_tracking_service.py
@@ -16,11 +16,14 @@ from __future__ import annotations
import dataclasses
import logging
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
+import litellm
from litellm.integrations.custom_logger import CustomLogger
from sqlalchemy.ext.asyncio import AsyncSession
@@ -35,6 +38,8 @@ class TokenCallRecord:
prompt_tokens: int
completion_tokens: int
total_tokens: int
+ cost_micros: int = 0
+ call_kind: str = "chat"
@dataclass
@@ -49,6 +54,8 @@ class TurnTokenAccumulator:
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
+ cost_micros: int = 0,
+ call_kind: str = "chat",
) -> None:
self.calls.append(
TokenCallRecord(
@@ -56,20 +63,28 @@ class TurnTokenAccumulator:
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
+ cost_micros=cost_micros,
+ call_kind=call_kind,
)
)
def per_message_summary(self) -> dict[str, dict[str, int]]:
- """Return token counts grouped by model name."""
+ """Return token counts (and cost) grouped by model name."""
by_model: dict[str, dict[str, int]] = {}
for c in self.calls:
entry = by_model.setdefault(
c.model,
- {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+ {
+ "prompt_tokens": 0,
+ "completion_tokens": 0,
+ "total_tokens": 0,
+ "cost_micros": 0,
+ },
)
entry["prompt_tokens"] += c.prompt_tokens
entry["completion_tokens"] += c.completion_tokens
entry["total_tokens"] += c.total_tokens
+ entry["cost_micros"] += c.cost_micros
return by_model
@property
@@ -84,6 +99,21 @@ class TurnTokenAccumulator:
def total_completion_tokens(self) -> int:
return sum(c.completion_tokens for c in self.calls)
+ @property
+ def total_cost_micros(self) -> int:
+ """Sum of per-call ``cost_micros`` across the entire turn.
+
+ Used by ``stream_new_chat`` to debit a premium turn's actual
+ provider cost (in micro-USD) from the user's premium credit
+ balance. ``cost_micros`` per call is captured by
+ ``TokenTrackingCallback.async_log_success_event`` from
+ ``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
+ with multiple fallback paths so OpenRouter dynamic models and
+        custom Azure deployments still bill correctly via the aliases
+        ``pricing_registration`` registered at startup.
+ """
+ return sum(c.cost_micros for c in self.calls)
+
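+    # Example: a turn whose calls sum to ``total_cost_micros == 123_456``
+    # debits $0.123456 at finalize; UIs convert micro-USD to dollars by
+    # dividing by 1_000_000 (per the premium_* docstrings).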
def serialized_calls(self) -> list[dict[str, Any]]:
return [dataclasses.asdict(c) for c in self.calls]
@@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
def start_turn() -> TurnTokenAccumulator:
- """Create a fresh accumulator for the current async context and return it."""
+ """Create a fresh accumulator for the current async context and return it.
+
+ NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
+ short-lived per-call billable wrappers (image generation REST endpoint,
+ vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
+ ContextVar reset token to restore the *previous* accumulator on exit and
+ avoids leaking call records across reservations (issue B).
+ """
acc = TurnTokenAccumulator()
_turn_accumulator.set(acc)
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
@@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
return _turn_accumulator.get()
+@asynccontextmanager
+async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
+ """Async context manager that scopes a fresh ``TurnTokenAccumulator``
+ for the duration of the ``async with`` block, then *resets* the
+ ContextVar to its previous value on exit.
+
+ This is the safe primitive for per-call billable operations
+ (image generation, vision LLM extraction, podcasts) that may run
+ inside an outer chat turn or be called sequentially from the same
+ background worker. Using ``ContextVar.set`` without ``reset`` (as
+ :func:`start_turn` does) would leak the inner accumulator into the
+ outer scope, causing the outer chat turn to debit cost twice.
+
+ Usage::
+
+ async with scoped_turn() as acc:
+ await llm.ainvoke(...)
+ # acc.total_cost_micros captures cost from the LiteLLM callback
+ # Outer accumulator (if any) is restored here.
+ """
+ acc = TurnTokenAccumulator()
+ token = _turn_accumulator.set(acc)
+ logger.debug(
+ "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
+ id(acc),
+ token,
+ )
+ try:
+ yield acc
+ finally:
+ _turn_accumulator.reset(token)
+ logger.debug(
+ "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
+ id(acc),
+ len(acc.calls),
+ acc.total_cost_micros,
+ )
+
+
+def _extract_cost_usd(
+ kwargs: dict[str, Any],
+ response_obj: Any,
+ model: str,
+ prompt_tokens: int,
+ completion_tokens: int,
+ is_image: bool = False,
+) -> float:
+ """Best-effort USD cost extraction for a single LLM/image call.
+
+ Tries four sources in priority order and returns the first that
+ yields a positive number; returns 0.0 if all four fail (the call
+ will then debit nothing from the user's balance — fail-safe).
+
+ Sources:
+ 1. ``kwargs["response_cost"]`` — LiteLLM's standard callback
+ field, populated for ``Router.acompletion`` since PR #12500.
+ 2. ``response_obj._hidden_params["response_cost"]`` — same value
+ exposed on the response itself.
+ 3. ``litellm.completion_cost(completion_response=response_obj)``
+ — recompute from the response and LiteLLM's pricing table.
+ 4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
+ — manual fallback for OpenRouter/custom-Azure models that
+ only resolve via aliases registered by
+ ``pricing_registration`` at startup. **Skipped for image
+ responses** — ``cost_per_token`` does not support ``ImageResponse``
+ and would raise; the cost map for image-gen lives in different
+ keys (``output_cost_per_image``) handled by ``completion_cost``.
+ """
+ cost = kwargs.get("response_cost")
+ if cost is not None:
+ try:
+ value = float(cost)
+ except (TypeError, ValueError):
+ value = 0.0
+ if value > 0:
+ return value
+
+ hidden = getattr(response_obj, "_hidden_params", None) or {}
+ if isinstance(hidden, dict):
+ cost = hidden.get("response_cost")
+ if cost is not None:
+ try:
+ value = float(cost)
+ except (TypeError, ValueError):
+ value = 0.0
+ if value > 0:
+ return value
+
+ try:
+ value = float(litellm.completion_cost(completion_response=response_obj))
+ if value > 0:
+ return value
+ except Exception as exc:
+ if is_image:
+ # Image-gen path: OpenRouter's image responses can omit
+ # ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
+ # then *raises* (no cost map for OpenRouter image models).
+ # Bail out with a warning rather than falling through to
+ # cost_per_token (which is also incompatible with ImageResponse).
+ logger.warning(
+ "[TokenTracking] completion_cost failed for image model=%s "
+ "(provider may have omitted usage.cost). Debiting 0. "
+ "Cause: %s",
+ model,
+ exc,
+ )
+ return 0.0
+ logger.debug(
+ "[TokenTracking] completion_cost failed for model=%s: %s", model, exc
+ )
+
+ if is_image:
+ # Never call cost_per_token for ImageResponse — keys mismatch and
+ # the function is documented chat-only.
+ return 0.0
+
+ if model and (prompt_tokens > 0 or completion_tokens > 0):
+ try:
+ prompt_cost, completion_cost = litellm.cost_per_token(
+ model=model,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ )
+ value = float(prompt_cost) + float(completion_cost)
+ if value > 0:
+ return value
+ except Exception as exc:
+ logger.debug(
+ "[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
+ )
+
+ return 0.0
+
+
class TokenTrackingCallback(CustomLogger):
"""LiteLLM callback that captures token usage into the turn accumulator."""
@@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
)
return
+ # Detect image generation responses — they have a different usage
+ # shape (ImageUsage with input_tokens/output_tokens) and require a
+ # different cost-extraction path. We probe by class name to avoid a
+ # hard import dependency on litellm internals.
+ response_cls = type(response_obj).__name__
+ is_image = response_cls == "ImageResponse"
+
usage = getattr(response_obj, "usage", None)
if not usage:
logger.debug(
@@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
)
return
- prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
- completion_tokens = getattr(usage, "completion_tokens", 0) or 0
- total_tokens = getattr(usage, "total_tokens", 0) or 0
+ if is_image:
+ # ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
+ # (not prompt_tokens/completion_tokens). Several providers
+ # populate only one or neither (e.g. OpenRouter's gpt-image-1
+ # passes through `input_tokens` from the prompt but no
+ # completion); fall through gracefully to 0.
+ prompt_tokens = getattr(usage, "input_tokens", 0) or 0
+ completion_tokens = getattr(usage, "output_tokens", 0) or 0
+ total_tokens = (
+ getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
+ )
+ call_kind = "image_generation"
+ else:
+ prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
+ completion_tokens = getattr(usage, "completion_tokens", 0) or 0
+ total_tokens = getattr(usage, "total_tokens", 0) or 0
+ call_kind = "chat"
model = kwargs.get("model", "unknown")
+ cost_usd = _extract_cost_usd(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ model=model,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ is_image=is_image,
+ )
+ cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0
+
+ if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
+ logger.warning(
+ "[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
+ "kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
+ "input_cost_per_token/output_cost_per_token (or rely on response_cost "
+ "for image generation).",
+ model,
+ prompt_tokens,
+ completion_tokens,
+ call_kind,
+ )
+
acc.add(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
+ cost_micros=cost_micros,
+ call_kind=call_kind,
)
logger.info(
- "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
+ "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
+ "cost=$%.6f (%d micros) (accumulator now has %d calls)",
model,
+ call_kind,
prompt_tokens,
completion_tokens,
total_tokens,
+ cost_usd,
+ cost_micros,
len(acc.calls),
)
@@ -168,6 +388,7 @@ async def record_token_usage(
prompt_tokens: int = 0,
completion_tokens: int = 0,
total_tokens: int = 0,
+ cost_micros: int = 0,
model_breakdown: dict[str, Any] | None = None,
call_details: dict[str, Any] | None = None,
thread_id: int | None = None,
@@ -185,6 +406,7 @@ async def record_token_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
+ cost_micros=cost_micros,
model_breakdown=model_breakdown,
call_details=call_details,
thread_id=thread_id,
@@ -194,11 +416,12 @@ async def record_token_usage(
)
session.add(record)
logger.debug(
- "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
+ "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
usage_type,
prompt_tokens,
completion_tokens,
total_tokens,
+ cost_micros,
)
return record
except Exception:
diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py
index 0d782ab2b..ed5de921c 100644
--- a/surfsense_backend/app/services/vision_llm_router_service.py
+++ b/surfsense_backend/app/services/vision_llm_router_service.py
@@ -3,6 +3,8 @@ from typing import Any
from litellm import Router
+from app.services.provider_api_base import resolve_api_base
+
logger = logging.getLogger(__name__)
VISION_AUTO_MODE_ID = 0
@@ -108,10 +110,11 @@ class VisionLLMRouterService:
if not config.get("model_name") or not config.get("api_key"):
return None
+ provider = config.get("provider", "").upper()
if config.get("custom_provider"):
- model_string = f"{config['custom_provider']}/{config['model_name']}"
+ provider_prefix = config["custom_provider"]
+ model_string = f"{provider_prefix}/{config['model_name']}"
else:
- provider = config.get("provider", "").upper()
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
@@ -120,8 +123,13 @@ class VisionLLMRouterService:
"api_key": config.get("api_key"),
}
- if config.get("api_base"):
- litellm_params["api_base"] = config["api_base"]
+ api_base = resolve_api_base(
+ provider=provider,
+ provider_prefix=provider_prefix,
+ config_api_base=config.get("api_base"),
+ )
+ if api_base:
+ litellm_params["api_base"] = api_base
if config.get("api_version"):
litellm_params["api_version"] = config["api_version"]
diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
index 953011ecf..937877473 100644
--- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
@@ -9,7 +9,13 @@ from sqlalchemy import select
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
+from app.config import config as app_config
from app.db import Podcast, PodcastStatus
+from app.services.billable_calls import (
+ QuotaInsufficientError,
+ _resolve_agent_billing_for_search_space,
+ billable_call,
+)
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@@ -96,6 +102,31 @@ async def _generate_content_podcast(
podcast.status = PodcastStatus.GENERATING
await session.commit()
+ try:
+ (
+ owner_user_id,
+ billing_tier,
+ base_model,
+ ) = await _resolve_agent_billing_for_search_space(
+ session,
+ search_space_id,
+ thread_id=podcast.thread_id,
+ )
+ except ValueError as resolve_err:
+ logger.error(
+ "Podcast %s: cannot resolve billing for search_space=%s: %s",
+ podcast.id,
+ search_space_id,
+ resolve_err,
+ )
+ podcast.status = PodcastStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "podcast_id": podcast.id,
+ "reason": "billing_resolution_failed",
+ }
+
graph_config = {
"configurable": {
"podcast_title": podcast.title,
@@ -109,9 +140,39 @@ async def _generate_content_podcast(
db_session=session,
)
- graph_result = await podcaster_graph.ainvoke(
- initial_state, config=graph_config
- )
+ try:
+ async with billable_call(
+ user_id=owner_user_id,
+ search_space_id=search_space_id,
+ billing_tier=billing_tier,
+ base_model=base_model,
+ quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
+ usage_type="podcast_generation",
+ thread_id=podcast.thread_id,
+ call_details={
+ "podcast_id": podcast.id,
+ "title": podcast.title,
+ },
+ ):
+ graph_result = await podcaster_graph.ainvoke(
+ initial_state, config=graph_config
+ )
+ except QuotaInsufficientError as exc:
+ logger.info(
+ "Podcast %s denied: out of premium credits "
+ "(used=%d/%d remaining=%d)",
+ podcast.id,
+ exc.used_micros,
+ exc.limit_micros,
+ exc.remaining_micros,
+ )
+ podcast.status = PodcastStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "podcast_id": podcast.id,
+ "reason": "premium_quota_exhausted",
+ }
podcast_transcript = graph_result.get("podcast_transcript", [])
file_path = graph_result.get("final_podcast_file_path", "")
diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
index 7880b385f..4f0c427d9 100644
--- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
@@ -9,7 +9,13 @@ from sqlalchemy import select
from app.agents.video_presentation.graph import graph as video_presentation_graph
from app.agents.video_presentation.state import State as VideoPresentationState
from app.celery_app import celery_app
+from app.config import config as app_config
from app.db import VideoPresentation, VideoPresentationStatus
+from app.services.billable_calls import (
+ QuotaInsufficientError,
+ _resolve_agent_billing_for_search_space,
+ billable_call,
+)
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@@ -97,6 +103,32 @@ async def _generate_video_presentation(
video_pres.status = VideoPresentationStatus.GENERATING
await session.commit()
+ try:
+ (
+ owner_user_id,
+ billing_tier,
+ base_model,
+ ) = await _resolve_agent_billing_for_search_space(
+ session,
+ search_space_id,
+ thread_id=video_pres.thread_id,
+ )
+ except ValueError as resolve_err:
+ logger.error(
+ "VideoPresentation %s: cannot resolve billing for "
+ "search_space=%s: %s",
+ video_pres.id,
+ search_space_id,
+ resolve_err,
+ )
+ video_pres.status = VideoPresentationStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "video_presentation_id": video_pres.id,
+ "reason": "billing_resolution_failed",
+ }
+
graph_config = {
"configurable": {
"video_title": video_pres.title,
@@ -110,9 +142,39 @@ async def _generate_video_presentation(
db_session=session,
)
- graph_result = await video_presentation_graph.ainvoke(
- initial_state, config=graph_config
- )
+ try:
+ async with billable_call(
+ user_id=owner_user_id,
+ search_space_id=search_space_id,
+ billing_tier=billing_tier,
+ base_model=base_model,
+ quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
+ usage_type="video_presentation_generation",
+ thread_id=video_pres.thread_id,
+ call_details={
+ "video_presentation_id": video_pres.id,
+ "title": video_pres.title,
+ },
+ ):
+ graph_result = await video_presentation_graph.ainvoke(
+ initial_state, config=graph_config
+ )
+ except QuotaInsufficientError as exc:
+ logger.info(
+ "VideoPresentation %s denied: out of premium credits "
+ "(used=%d/%d remaining=%d)",
+ video_pres.id,
+ exc.used_micros,
+ exc.limit_micros,
+ exc.remaining_micros,
+ )
+ video_pres.status = VideoPresentationStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "video_presentation_id": video_pres.id,
+ "reason": "premium_quota_exhausted",
+ }
# Serialize slides (parsed content + audio info merged)
slides_raw = graph_result.get("slides", [])
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index dbfe9a67b..31c0d7d6d 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -2236,8 +2236,10 @@ async def stream_new_chat(
accumulator = start_turn()
- # Premium quota tracking state
- _premium_reserved = 0
+ # Premium credit (USD micro-units) tracking state. Stores the
+ # amount reserved up front so we can release it on cancellation
+ # and finalize-debit the actual provider cost reported by LiteLLM.
+ _premium_reserved_micros = 0
_premium_request_id: str | None = None
_emit_stream_error = partial(
@@ -2331,23 +2333,28 @@ async def stream_new_chat(
if _needs_premium_quota:
import uuid as _uuid
- from app.config import config as _app_config
- from app.services.token_quota_service import TokenQuotaService
+ from app.services.token_quota_service import (
+ TokenQuotaService,
+ estimate_call_reserve_micros,
+ )
_premium_request_id = _uuid.uuid4().hex[:16]
- reserve_amount = min(
- agent_config.quota_reserve_tokens
- or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
- _app_config.QUOTA_MAX_RESERVE_PER_CALL,
+ _agent_litellm_params = agent_config.litellm_params or {}
+ _agent_base_model = (
+ _agent_litellm_params.get("base_model") or agent_config.model_name or ""
+ )
+ reserve_amount_micros = estimate_call_reserve_micros(
+ base_model=_agent_base_model,
+ quota_reserve_tokens=agent_config.quota_reserve_tokens,
)
async with shielded_async_session() as quota_session:
quota_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_premium_request_id,
- reserve_tokens=reserve_amount,
+ reserve_micros=reserve_amount_micros,
)
- _premium_reserved = reserve_amount
+ _premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed:
if requested_llm_config_id == 0:
try:
@@ -2382,7 +2389,7 @@ async def stream_new_chat(
yield streaming_service.format_done()
return
_premium_request_id = None
- _premium_reserved = 0
+ _premium_reserved_micros = 0
_log_chat_stream_error(
flow=flow,
error_kind="premium_quota_exhausted",
@@ -3020,9 +3027,10 @@ async def stream_new_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
+ "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -3033,6 +3041,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@@ -3060,7 +3069,11 @@ async def stream_new_chat(
chat_id, generated_title
)
- # Finalize premium quota with actual tokens.
+ # Finalize premium credit debit with the actual provider cost
+ # reported by LiteLLM, summed across every call in the turn.
+    # Mirrors the old token-based behaviour of "premium turn → all calls
+    # count", so free sub-agent calls during a premium turn still
+    # contribute to the bill (they're $0 in practice anyway).
if _premium_request_id and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
@@ -3070,11 +3083,11 @@ async def stream_new_chat(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_premium_request_id,
- actual_tokens=accumulator.grand_total,
- reserved_tokens=_premium_reserved,
+ actual_micros=accumulator.total_cost_micros,
+ reserved_micros=_premium_reserved_micros,
)
_premium_request_id = None
- _premium_reserved = 0
+ _premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s",
@@ -3084,9 +3097,10 @@ async def stream_new_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] normal new_chat: calls=%d total=%d summary=%s",
+ "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -3097,6 +3111,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@@ -3190,7 +3205,7 @@ async def stream_new_chat(
end_turn(str(chat_id))
# Release premium reservation if not finalized
- if _premium_request_id and _premium_reserved > 0 and user_id:
+ if _premium_request_id and _premium_reserved_micros > 0 and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
@@ -3198,9 +3213,9 @@ async def stream_new_chat(
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=UUID(user_id),
- reserved_tokens=_premium_reserved,
+ reserved_micros=_premium_reserved_micros,
)
- _premium_reserved = 0
+ _premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to release premium quota for user %s", user_id
@@ -3369,8 +3384,8 @@ async def stream_resume_chat(
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
)
- # Premium quota reservation (same logic as stream_new_chat)
- _resume_premium_reserved = 0
+ # Premium credit reservation (same logic as stream_new_chat).
+ _resume_premium_reserved_micros = 0
_resume_premium_request_id: str | None = None
_resume_needs_premium = (
agent_config is not None and user_id and agent_config.is_premium
@@ -3378,23 +3393,30 @@ async def stream_resume_chat(
if _resume_needs_premium:
import uuid as _uuid
- from app.config import config as _app_config
- from app.services.token_quota_service import TokenQuotaService
+ from app.services.token_quota_service import (
+ TokenQuotaService,
+ estimate_call_reserve_micros,
+ )
_resume_premium_request_id = _uuid.uuid4().hex[:16]
- reserve_amount = min(
- agent_config.quota_reserve_tokens
- or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
- _app_config.QUOTA_MAX_RESERVE_PER_CALL,
+ _resume_litellm_params = agent_config.litellm_params or {}
+ _resume_base_model = (
+ _resume_litellm_params.get("base_model")
+ or agent_config.model_name
+ or ""
+ )
+ reserve_amount_micros = estimate_call_reserve_micros(
+ base_model=_resume_base_model,
+ quota_reserve_tokens=agent_config.quota_reserve_tokens,
)
async with shielded_async_session() as quota_session:
quota_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_resume_premium_request_id,
- reserve_tokens=reserve_amount,
+ reserve_micros=reserve_amount_micros,
)
- _resume_premium_reserved = reserve_amount
+ _resume_premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed:
if requested_llm_config_id == 0:
try:
@@ -3429,7 +3451,7 @@ async def stream_resume_chat(
yield streaming_service.format_done()
return
_resume_premium_request_id = None
- _resume_premium_reserved = 0
+ _resume_premium_reserved_micros = 0
_log_chat_stream_error(
flow="resume",
error_kind="premium_quota_exhausted",
@@ -3746,9 +3768,10 @@ async def stream_resume_chat(
if stream_result.is_interrupted:
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
+ "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -3759,6 +3782,7 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@@ -3768,7 +3792,9 @@ async def stream_resume_chat(
yield streaming_service.format_done()
return
- # Finalize premium quota for resume path
+ # Finalize premium credit debit for resume path with the actual
+ # provider cost reported by LiteLLM (sum of cost across all
+ # calls in the turn).
if _resume_premium_request_id and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
@@ -3778,11 +3804,11 @@ async def stream_resume_chat(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_resume_premium_request_id,
- actual_tokens=accumulator.grand_total,
- reserved_tokens=_resume_premium_reserved,
+ actual_micros=accumulator.total_cost_micros,
+ reserved_micros=_resume_premium_reserved_micros,
)
_resume_premium_request_id = None
- _resume_premium_reserved = 0
+ _resume_premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s (resume)",
@@ -3792,9 +3818,10 @@ async def stream_resume_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
+ "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -3805,6 +3832,7 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@@ -3855,7 +3883,11 @@ async def stream_resume_chat(
end_turn(str(chat_id))
# Release premium reservation if not finalized
- if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
+ if (
+ _resume_premium_request_id
+ and _resume_premium_reserved_micros > 0
+ and user_id
+ ):
try:
from app.services.token_quota_service import TokenQuotaService
@@ -3863,9 +3895,9 @@ async def stream_resume_chat(
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=UUID(user_id),
- reserved_tokens=_resume_premium_reserved,
+ reserved_micros=_resume_premium_reserved_micros,
)
- _resume_premium_reserved = 0
+ _resume_premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to release premium quota for user %s (resume)", user_id
diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py
new file mode 100644
index 000000000..636b7de31
--- /dev/null
+++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py
@@ -0,0 +1,138 @@
+"""Unit tests for the image-generation route's billing-resolution helper.
+
+End-to-end "POST /image-generations returns 402" coverage requires the
+integration harness (real DB, real auth) and lives in
+``tests/integration/document_upload/`` alongside the other quota tests.
+This unit test focuses on the new ``_resolve_billing_for_image_gen``
+helper which:
+
+* Returns ``free`` for Auto mode, even when premium configs exist
+ (Auto-mode billing-tier surfacing is a follow-up).
+* Returns ``free`` for user-owned BYOK configs (positive IDs).
+* Returns the global config's ``billing_tier`` for negative IDs.
+* Honours the per-config ``quota_reserve_micros`` override when present.
+"""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_for_auto_mode(monkeypatch):
+ from app.routes import image_generation_routes
+ from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
+
+ search_space = SimpleNamespace(image_generation_config_id=None)
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, # Not consumed on this code path.
+ config_id=0, # IMAGE_GEN_AUTO_MODE_ID
+ search_space=search_space,
+ )
+ assert tier == "free"
+ assert model == "auto"
+ assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_for_premium_global_config(monkeypatch):
+ from app.config import config
+ from app.routes import image_generation_routes
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_IMAGE_GEN_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENAI",
+ "model_name": "gpt-image-1",
+ "billing_tier": "premium",
+ "quota_reserve_micros": 75_000,
+ },
+ {
+ "id": -2,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-2.5-flash-image",
+ "billing_tier": "free",
+ },
+ ],
+ raising=False,
+ )
+
+ search_space = SimpleNamespace(image_generation_config_id=None)
+
+ # Premium with override.
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=-1, search_space=search_space
+ )
+ assert tier == "premium"
+ assert model == "openai/gpt-image-1"
+ assert reserve == 75_000
+
+ # Free, no override → falls back to default.
+ from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
+
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=-2, search_space=search_space
+ )
+ assert tier == "free"
+ # Provider-prefixed model string for OpenRouter.
+ assert "google/gemini-2.5-flash-image" in model
+ assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_for_user_owned_byok_is_free():
+ """User-owned BYOK configs (positive IDs) cost the user nothing on
+ our side — they pay the provider directly. Always free.
+ """
+ from app.routes import image_generation_routes
+ from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
+
+ search_space = SimpleNamespace(image_generation_config_id=None)
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=42, search_space=search_space
+ )
+ assert tier == "free"
+ assert model == "user_byok"
+ assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
+ """When the request omits ``image_generation_config_id``, the helper
+ must consult the search space's default — so a search space pinned
+ to a premium global config still gates new requests by quota.
+ """
+ from app.config import config
+ from app.routes import image_generation_routes
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_IMAGE_GEN_CONFIGS",
+ [
+ {
+ "id": -7,
+ "provider": "OPENAI",
+ "model_name": "gpt-image-1",
+ "billing_tier": "premium",
+ }
+ ],
+ raising=False,
+ )
+
+ search_space = SimpleNamespace(image_generation_config_id=-7)
+ (
+ tier,
+ model,
+ _reserve,
+ ) = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=None, search_space=search_space
+ )
+ assert tier == "premium"
+ assert model == "openai/gpt-image-1"
diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py
new file mode 100644
index 000000000..fa8819b39
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py
@@ -0,0 +1,436 @@
+"""Unit tests for ``_resolve_agent_billing_for_search_space``.
+
+Validates the resolver used by Celery podcast/video tasks to compute
+``(owner_user_id, billing_tier, base_model)`` from a search space and its
+agent LLM config. The resolver mirrors chat's billing-resolution pattern at
+``stream_new_chat.py:2294-2351`` and is the single integration point that
+prevents Auto-mode podcast/video from leaking premium credit.
+
+Coverage:
+
+* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
+  global → returns tier ``"premium"`` with that global's base model.
+* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
+  global → returns tier ``"free"`` with that global's base model.
+* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
+ → always ``"free"``.
+* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
+ hitting the pin service.
+* Negative id (no Auto) → uses ``get_global_llm_config``'s
+ ``billing_tier``.
+* Positive id (user BYOK) → always ``"free"``.
+* Search space not found → raises ``ValueError``.
+* ``agent_llm_id`` is None → raises ``ValueError``.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+from uuid import UUID, uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeExecResult:
+ def __init__(self, obj):
+ self._obj = obj
+
+ def scalars(self):
+ return self
+
+ def first(self):
+ return self._obj
+
+
+class _FakeSession:
+ """Tiny AsyncSession stub.
+
+ ``responses`` is a list of objects to return from successive
+ ``execute()`` calls (in order). The resolver makes at most two
+ ``execute()`` calls (search-space lookup, then optionally NewLLMConfig
+ lookup), so two queued responses cover the matrix.
+ """
+
+ def __init__(self, responses: list):
+ self._responses = list(responses)
+
+ async def execute(self, _stmt):
+ if not self._responses:
+ return _FakeExecResult(None)
+ return _FakeExecResult(self._responses.pop(0))
+
+ async def commit(self) -> None:
+ pass
+
+
+@dataclass
+class _FakePinResolution:
+ resolved_llm_config_id: int
+ resolved_tier: str = "premium"
+ from_existing_pin: bool = False
+
+
+def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
+ return SimpleNamespace(
+ id=42,
+ agent_llm_id=agent_llm_id,
+ user_id=user_id,
+ )
+
+
+def _make_byok_config(
+ *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
+) -> SimpleNamespace:
+ return SimpleNamespace(
+ id=id_,
+ model_name=model_name,
+ litellm_params={"base_model": base_model} if base_model else {},
+ )
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
+ """Auto + thread → pin service resolves to negative-id premium config →
+    resolver returns tier ``"premium"`` with that config's base model."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ # Mock the pin service to return a concrete premium config id.
+ async def _fake_resolve_pin(
+ sess,
+ *,
+ thread_id,
+ search_space_id,
+ user_id,
+ selected_llm_config_id,
+ force_repin_free=False,
+ ):
+ assert selected_llm_config_id == 0
+ assert thread_id == 99
+ return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")
+
+ # Mock global config lookup to return a premium entry.
+ def _fake_get_global(cfg_id):
+ if cfg_id == -1:
+ return {
+ "id": -1,
+ "model_name": "gpt-5.4",
+ "billing_tier": "premium",
+ "litellm_params": {"base_model": "gpt-5.4"},
+ }
+ return None
+
+ # Lazy imports inside the resolver — patch the *target* modules so the
+ # imported names resolve to our fakes.
+ import app.services.auto_model_pin_service as pin_module
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "premium"
+ assert base_model == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
+    """Auto + thread → pin returns a negative-id free config → resolver
+    returns tier ``"free"`` with that config's base model. Same path the
+    pin service takes for out-of-credit users (graceful degradation)."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ async def _fake_resolve_pin(
+ sess,
+ *,
+ thread_id,
+ search_space_id,
+ user_id,
+ selected_llm_config_id,
+ force_repin_free=False,
+ ):
+ return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")
+
+ def _fake_get_global(cfg_id):
+ if cfg_id == -3:
+ return {
+ "id": -3,
+ "model_name": "openrouter/free-model",
+ "billing_tier": "free",
+ "litellm_params": {"base_model": "openrouter/free-model"},
+ }
+ return None
+
+ import app.services.auto_model_pin_service as pin_module
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "openrouter/free-model"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
+ """Auto + thread → pin returns positive-id BYOK config → resolver
+ returns ``("free", ...)`` (BYOK is always free per
+ ``AgentConfig.from_new_llm_config``)."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
+ byok_cfg = _make_byok_config(
+ id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
+ )
+ session = _FakeSession([search_space, byok_cfg])
+
+ async def _fake_resolve_pin(
+ sess,
+ *,
+ thread_id,
+ search_space_id,
+ user_id,
+ selected_llm_config_id,
+ force_repin_free=False,
+ ):
+ return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")
+
+ import app.services.auto_model_pin_service as pin_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "anthropic/claude-3-haiku"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_without_thread_id_falls_back_to_free():
+ """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
+ the pin service. Forward-compat fallback for any future direct-API
+ entrypoint that doesn't have a chat thread."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=None
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "auto"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
+ """If the pin service raises ``ValueError`` (thread missing /
+ mismatched search space), the resolver should log and return free
+ rather than killing the whole task."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ async def _fake_resolve_pin(*args, **kwargs):
+ raise ValueError("thread missing")
+
+ import app.services.auto_model_pin_service as pin_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "auto"
+
+
+@pytest.mark.asyncio
+async def test_negative_id_premium_global_returns_premium(monkeypatch):
+ """Explicit negative agent_llm_id → ``get_global_llm_config`` →
+ return its ``billing_tier``."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])
+
+ def _fake_get_global(cfg_id):
+ return {
+ "id": cfg_id,
+ "model_name": "gpt-5.4",
+ "billing_tier": "premium",
+ "litellm_params": {"base_model": "gpt-5.4"},
+ }
+
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "premium"
+ assert base_model == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_negative_id_free_global_returns_free(monkeypatch):
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])
+
+ def _fake_get_global(cfg_id):
+ return {
+ "id": cfg_id,
+ "model_name": "openrouter/some-free",
+ "billing_tier": "free",
+ "litellm_params": {"base_model": "openrouter/some-free"},
+ }
+
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=None
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "openrouter/some-free"
+
+
+@pytest.mark.asyncio
+async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
+ """When the global config has no ``litellm_params.base_model``, the
+ resolver falls back to ``model_name`` — matching chat's behavior."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])
+
+ def _fake_get_global(cfg_id):
+ return {
+ "id": cfg_id,
+ "model_name": "fallback-model",
+ "billing_tier": "premium",
+ # No litellm_params.
+ }
+
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ _, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42
+ )
+
+ assert tier == "premium"
+ assert base_model == "fallback-model"
+
+
+@pytest.mark.asyncio
+async def test_positive_id_byok_is_always_free():
+ """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
+ regardless of underlying provider tier."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
+ byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
+ session = _FakeSession([search_space, byok_cfg])
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "anthropic/claude-3.5-sonnet"
+
+
+@pytest.mark.asyncio
+async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
+ """If the BYOK config row is missing/deleted but the search space still
+ points at it, the resolver still returns free (no debit) with an empty
+ base_model — billable_call's premium path is skipped, no harm done."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == ""
+
+
+@pytest.mark.asyncio
+async def test_search_space_not_found_raises_value_error():
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ session = _FakeSession([None])
+
+ with pytest.raises(ValueError, match="Search space"):
+ await _resolve_agent_billing_for_search_space(session, search_space_id=999)
+
+
+@pytest.mark.asyncio
+async def test_agent_llm_id_none_raises_value_error():
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])
+
+ with pytest.raises(ValueError, match="agent_llm_id"):
+ await _resolve_agent_billing_for_search_space(session, search_space_id=42)
diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py
new file mode 100644
index 000000000..86de5f23d
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_billable_call.py
@@ -0,0 +1,432 @@
+"""Unit tests for the ``billable_call`` async context manager.
+
+Covers the per-call premium-credit lifecycle for image generation and
+vision LLM extraction:
+
+* Free configs bypass reserve/finalize but still write an audit row.
+* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
+ route layer).
+* Successful premium calls reserve, yield the accumulator, then finalize
+ with the LiteLLM-reported actual cost — and write an audit row.
+* Failed premium calls release the reservation so credit isn't leaked.
+* All quota DB ops happen inside their OWN ``shielded_async_session``,
+ isolating them from the caller's transaction (issue A).
+"""
+
+from __future__ import annotations
+
+import contextlib
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeQuotaResult:
+ def __init__(
+ self,
+ *,
+ allowed: bool,
+ used: int = 0,
+ limit: int = 5_000_000,
+ remaining: int = 5_000_000,
+ ) -> None:
+ self.allowed = allowed
+ self.used = used
+ self.limit = limit
+ self.remaining = remaining
+
+
+class _FakeSession:
+ """Minimal AsyncSession stub — record commits for assertion."""
+
+ def __init__(self) -> None:
+ self.committed = False
+ self.added: list[Any] = []
+
+ def add(self, obj: Any) -> None:
+ self.added.append(obj)
+
+ async def commit(self) -> None:
+ self.committed = True
+
+ async def close(self) -> None:
+ pass
+
+
+_SESSIONS_USED: list[_FakeSession] = []
+
+
+@contextlib.asynccontextmanager
+async def _fake_shielded_session():
+    s = _FakeSession()
+    _SESSIONS_USED.append(s)
+    yield s
+
+
+def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
+ """Wire fake reserve/finalize/release/session helpers."""
+ _SESSIONS_USED.clear()
+ reserve_calls: list[dict[str, Any]] = []
+ finalize_calls: list[dict[str, Any]] = []
+ release_calls: list[dict[str, Any]] = []
+
+ async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
+ reserve_calls.append(
+ {
+ "user_id": user_id,
+ "reserve_micros": reserve_micros,
+ "request_id": request_id,
+ }
+ )
+ return reserve_result
+
+ async def _fake_finalize(
+ *, db_session, user_id, request_id, actual_micros, reserved_micros
+ ):
+ finalize_calls.append(
+ {
+ "user_id": user_id,
+ "actual_micros": actual_micros,
+ "reserved_micros": reserved_micros,
+ }
+ )
+ return finalize_result or _FakeQuotaResult(allowed=True)
+
+ async def _fake_release(*, db_session, user_id, reserved_micros):
+ release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros})
+
+ record_calls: list[dict[str, Any]] = []
+
+ async def _fake_record(session, **kwargs):
+ record_calls.append(kwargs)
+ return object()
+
+ monkeypatch.setattr(
+ "app.services.billable_calls.TokenQuotaService.premium_reserve",
+ _fake_reserve,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.TokenQuotaService.premium_finalize",
+ _fake_finalize,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.TokenQuotaService.premium_release",
+ _fake_release,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.shielded_async_session",
+ _fake_shielded_session,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.record_token_usage",
+ _fake_record,
+ raising=False,
+ )
+
+ return {
+ "reserve": reserve_calls,
+ "finalize": finalize_calls,
+ "release": release_calls,
+ "record": record_calls,
+ }
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="free",
+ base_model="openai/gpt-image-1",
+ usage_type="image_generation",
+ ) as acc:
+ # Simulate a captured cost — the accumulator is fed by the LiteLLM
+ # callback in real life, here we add it manually.
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=37_000,
+ call_kind="image_generation",
+ )
+
+ assert spies["reserve"] == []
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+ # Free still audits.
+ assert len(spies["record"]) == 1
+ assert spies["record"][0]["usage_type"] == "image_generation"
+ assert spies["record"][0]["cost_micros"] == 37_000
+
+
+@pytest.mark.asyncio
+async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
+ from app.services.billable_calls import (
+ QuotaInsufficientError,
+ billable_call,
+ )
+
+ spies = _patch_isolation_layer(
+ monkeypatch,
+ reserve_result=_FakeQuotaResult(
+ allowed=False, used=5_000_000, limit=5_000_000, remaining=0
+ ),
+ )
+ user_id = uuid4()
+
+ with pytest.raises(QuotaInsufficientError) as exc_info:
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ):
+ pytest.fail("body should not run when reserve is denied")
+
+ err = exc_info.value
+ assert err.usage_type == "image_generation"
+ assert err.used_micros == 5_000_000
+ assert err.limit_micros == 5_000_000
+ assert err.remaining_micros == 0
+ # Reserve was attempted, but no finalize/release on a denied reserve
+ # — the reservation never actually held credit.
+ assert len(spies["reserve"]) == 1
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+ # Denied premium calls do NOT create an audit row (no work happened).
+ assert spies["record"] == []
+
+
+@pytest.mark.asyncio
+async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ) as acc:
+ # LiteLLM callback would normally fill this — simulate $0.04 image.
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=40_000,
+ call_kind="image_generation",
+ )
+
+ assert len(spies["reserve"]) == 1
+ assert spies["reserve"][0]["reserve_micros"] == 50_000
+ assert len(spies["finalize"]) == 1
+ assert spies["finalize"][0]["actual_micros"] == 40_000
+ assert spies["finalize"][0]["reserved_micros"] == 50_000
+ assert spies["release"] == []
+ # And audit row written with the actual debited cost.
+ assert spies["record"][0]["cost_micros"] == 40_000
+ # Each quota op opened its OWN session — proves session isolation.
+ assert len(_SESSIONS_USED) >= 3
+    # Reserve/finalize go through the stubbed TokenQuotaService.* and never
+    # commit on our fake sessions; the audit write does. Assert that at
+    # least one session committed.
+    assert any(s.committed for s in _SESSIONS_USED)
+
+
+@pytest.mark.asyncio
+async def test_premium_failure_releases_reservation(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ class _ProviderError(Exception):
+ pass
+
+ with pytest.raises(_ProviderError):
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ):
+ raise _ProviderError("OpenRouter 503")
+
+ assert len(spies["reserve"]) == 1
+ assert spies["finalize"] == []
+ # Failure path: release the held reservation.
+ assert len(spies["release"]) == 1
+ assert spies["release"][0]["reserved_micros"] == 50_000
+
+
+@pytest.mark.asyncio
+async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
+ """When ``quota_reserve_micros_override`` is None we fall back to
+ ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.
+ Vision LLM calls take this path (token-priced models).
+ """
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+
+ captured_estimator_calls: list[dict[str, Any]] = []
+
+ def _fake_estimate(*, base_model, quota_reserve_tokens):
+ captured_estimator_calls.append(
+ {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
+ )
+ return 12_345
+
+ monkeypatch.setattr(
+ "app.services.billable_calls.estimate_call_reserve_micros",
+ _fake_estimate,
+ raising=False,
+ )
+
+ user_id = uuid4()
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=1,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ usage_type="vision_extraction",
+ ):
+ pass
+
+ assert captured_estimator_calls == [
+ {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
+ ]
+ assert spies["reserve"][0]["reserve_micros"] == 12_345
+
+
+# ---------------------------------------------------------------------------
+# Podcast / video-presentation usage_type coverage
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
+ """Free podcast configs must skip reserve/finalize but still emit a
+ ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
+ have full audit coverage of free-tier agent runs."""
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="free",
+ base_model="openrouter/some-free-model",
+ quota_reserve_micros_override=200_000,
+ usage_type="podcast_generation",
+ thread_id=99,
+ call_details={"podcast_id": 7, "title": "Test Podcast"},
+ ) as acc:
+ # Two transcript LLM calls aggregated into one accumulator.
+ acc.add(
+ model="openrouter/some-free-model",
+ prompt_tokens=1500,
+ completion_tokens=8000,
+ total_tokens=9500,
+ cost_micros=0,
+ call_kind="chat",
+ )
+
+ assert spies["reserve"] == []
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+
+ assert len(spies["record"]) == 1
+ row = spies["record"][0]
+ assert row["usage_type"] == "podcast_generation"
+ assert row["thread_id"] == 99
+ assert row["search_space_id"] == 42
+ assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
+
+
+@pytest.mark.asyncio
+async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
+ """Premium video-presentation runs that hit a denied reservation must
+ raise ``QuotaInsufficientError`` *before* the graph runs and must not
+ emit an audit row (no work happened)."""
+ from app.services.billable_calls import (
+ QuotaInsufficientError,
+ billable_call,
+ )
+
+ spies = _patch_isolation_layer(
+ monkeypatch,
+ reserve_result=_FakeQuotaResult(
+ allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
+ ),
+ )
+ user_id = uuid4()
+
+ with pytest.raises(QuotaInsufficientError) as exc_info:
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="gpt-5.4",
+ quota_reserve_micros_override=1_000_000,
+ usage_type="video_presentation_generation",
+ thread_id=99,
+ call_details={"video_presentation_id": 12, "title": "Test Video"},
+ ):
+ pytest.fail("body should not run when reserve is denied")
+
+ err = exc_info.value
+ assert err.usage_type == "video_presentation_generation"
+ assert err.remaining_micros == 500_000
+ assert spies["reserve"][0]["reserve_micros"] == 1_000_000
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+ assert spies["record"] == []
diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py
index 085740032..b635b4fe8 100644
--- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py
+++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py
@@ -214,3 +214,159 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
assert "openai/gpt-4o" in model_names
assert "openai/dall-e" not in model_names
assert "openai/completion-only" not in model_names
+
+
+# ---------------------------------------------------------------------------
+# _generate_image_gen_configs / _generate_vision_llm_configs
+# ---------------------------------------------------------------------------
+
+
+def test_generate_image_gen_configs_filters_by_image_output():
+ """Only models with ``output_modalities`` containing ``image`` are emitted.
+ Tool-calling and context filters are intentionally NOT applied — image
+ generation has nothing to do with tool calls and context windows.
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_image_gen_configs,
+ )
+
+ raw = [
+ # Pure image-gen model (small context, no tools — should still emit).
+ {
+ "id": "openai/gpt-image-1",
+ "architecture": {"output_modalities": ["image"]},
+ "context_length": 4_000,
+ "pricing": {"prompt": "0", "completion": "0"},
+ },
+ # Multi-modal: text+image output (should still emit).
+ {
+ "id": "google/gemini-2.5-flash-image",
+ "architecture": {"output_modalities": ["text", "image"]},
+ "context_length": 1_000_000,
+ "pricing": {"prompt": "0.000001", "completion": "0.000004"},
+ },
+ # Pure text model — must NOT emit.
+ {
+ "id": "openai/gpt-4o",
+ "architecture": {"output_modalities": ["text"]},
+ "context_length": 128_000,
+ "pricing": {"prompt": "0.000005", "completion": "0.000015"},
+ },
+ ]
+
+ cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
+ model_names = {c["model_name"] for c in cfgs}
+ assert "openai/gpt-image-1" in model_names
+ assert "google/gemini-2.5-flash-image" in model_names
+ assert "openai/gpt-4o" not in model_names
+
+ # Each config must carry ``billing_tier`` for routing in image_generation_routes.
+ for c in cfgs:
+ assert c["billing_tier"] in {"free", "premium"}
+ assert c["provider"] == "OPENROUTER"
+ assert c[_OPENROUTER_DYNAMIC_MARKER] is True
+
+
+def test_generate_image_gen_configs_assigns_image_id_offset():
+ """Image configs use a different id_offset (-20000) so their negative
+ IDs don't collide with chat configs (-10000) or vision configs (-30000).
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_image_gen_configs,
+ )
+
+ raw = [
+ {
+ "id": "openai/gpt-image-1",
+ "architecture": {"output_modalities": ["image"]},
+ "context_length": 4_000,
+ "pricing": {"prompt": "0", "completion": "0"},
+ }
+ ]
+ # Don't pass image_id_offset → use the module default (-20000).
+ cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
+    assert all(c["id"] <= -20_000 for c in cfgs)
+    # Stay above the vision offset so image IDs can't collide with vision configs.
+    assert all(c["id"] > -30_000 for c in cfgs)
+
+
+def test_generate_vision_llm_configs_filters_by_image_input_text_output():
+ """Vision LLMs must accept image input AND emit text — pure image-gen
+ (no text out) and text-only (no image in) models are excluded.
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_vision_llm_configs,
+ )
+
+ raw = [
+ # GPT-4o: vision LLM (image in, text out) — must emit.
+ {
+ "id": "openai/gpt-4o",
+ "architecture": {
+ "input_modalities": ["text", "image"],
+ "output_modalities": ["text"],
+ },
+ "context_length": 128_000,
+ "pricing": {"prompt": "0.000005", "completion": "0.000015"},
+ },
+ # Pure image generator — image *output*, no text out. Must NOT emit.
+ {
+ "id": "openai/gpt-image-1",
+ "architecture": {
+ "input_modalities": ["text"],
+ "output_modalities": ["image"],
+ },
+ "context_length": 4_000,
+ "pricing": {"prompt": "0", "completion": "0"},
+ },
+ # Pure text model (no image in). Must NOT emit.
+ {
+ "id": "anthropic/claude-3-haiku",
+ "architecture": {
+ "input_modalities": ["text"],
+ "output_modalities": ["text"],
+ },
+ "context_length": 200_000,
+ "pricing": {"prompt": "0.000001", "completion": "0.000005"},
+ },
+ ]
+
+ cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
+ names = {c["model_name"] for c in cfgs}
+ assert names == {"openai/gpt-4o"}
+
+ cfg = cfgs[0]
+ assert cfg["billing_tier"] == "premium"
+ # Pricing carried inline so pricing_registration can register vision
+ # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
+ # is cleared.
+ assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
+ assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
+ assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
+
+
+def test_generate_vision_llm_configs_drops_chat_only_filters():
+ """A small-context vision model that doesn't advertise tool calling is
+ still a valid vision LLM for "describe this image" prompts. The chat
+ filters (``supports_tool_calling``, ``has_sufficient_context``) must
+ NOT be applied to vision emission.
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_vision_llm_configs,
+ )
+
+ raw = [
+ {
+ "id": "tiny/vision-mini",
+ "architecture": {
+ "input_modalities": ["text", "image"],
+ "output_modalities": ["text"],
+ },
+ "supported_parameters": [], # no tools
+ "context_length": 4_000, # well below MIN_CONTEXT_LENGTH
+ "pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
+ }
+ ]
+
+ cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
+ assert len(cfgs) == 1
+ assert cfgs[0]["model_name"] == "tiny/vision-mini"
diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py
new file mode 100644
index 000000000..e97250ff2
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py
@@ -0,0 +1,447 @@
+"""Pricing registration unit tests.
+
+The pricing-registration module is what makes ``response_cost`` populate
+correctly for OpenRouter dynamic models and operator-defined Azure
+deployments, neither of which LiteLLM knows about natively. The tests
+exercise:
+
+* The alias generators emit every shape that LiteLLM's cost-callback might
+ use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
+ ``provider/base_model``, ``provider/model_name``, plus the special
+ ``azure_openai`` → ``azure`` normalisation).
+* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
+ with the right alias set and pricing values per provider.
+* Configs without a resolvable pair of cost values are skipped — never
+ registered as zero, since that would override pricing LiteLLM might
+ already know natively.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Alias generators
+# ---------------------------------------------------------------------------
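+
+# Alias shapes at a glance (a sketch that mirrors the assertions below; see
+# the individual tests for the exact guarantees):
+#
+#   _alias_set_for_openrouter("anthropic/claude-3-5-sonnet")
+#       == ["openrouter/anthropic/claude-3-5-sonnet", "anthropic/claude-3-5-sonnet"]
+#   _alias_set_for_yaml(provider="AZURE_OPENAI", model_name="gpt-5.4", base_model="gpt-5.4")
+#       includes "gpt-5.4", "azure_openai/gpt-5.4" and the normalised "azure/gpt-5.4"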
+
+
+def test_openrouter_alias_set_includes_prefixed_and_bare():
+ from app.services.pricing_registration import _alias_set_for_openrouter
+
+ aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet")
+ assert aliases == [
+ "openrouter/anthropic/claude-3-5-sonnet",
+ "anthropic/claude-3-5-sonnet",
+ ]
+
+
+def test_openrouter_alias_set_dedupes():
+ """If the model id is already prefixed with ``openrouter/``, the alias
+ set must not contain duplicates that would re-register the same key
+ twice.
+ """
+ from app.services.pricing_registration import _alias_set_for_openrouter
+
+ aliases = _alias_set_for_openrouter("openrouter/foo")
+ # The bare and prefixed variants compute to the same string here, so we
+ # at minimum require uniqueness.
+ assert len(aliases) == len(set(aliases))
+
+
+def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
+ """``azure_openai`` (our YAML provider slug) must register under
+ ``azure/`` so the LiteLLM Router's deployment-resolution path
+ (which uses provider ``azure``) finds the pricing too.
+ """
+ from app.services.pricing_registration import _alias_set_for_yaml
+
+ aliases = _alias_set_for_yaml(
+ provider="AZURE_OPENAI",
+ model_name="gpt-5.4",
+ base_model="gpt-5.4",
+ )
+ assert "gpt-5.4" in aliases
+ assert "azure_openai/gpt-5.4" in aliases
+ assert "azure/gpt-5.4" in aliases
+
+
+def test_yaml_alias_set_distinguishes_model_name_and_base_model():
+ """When ``model_name`` differs from ``base_model`` (operator labelled a
+ deployment), both must appear in the alias set since either may surface
+ in callbacks depending on the call path.
+ """
+ from app.services.pricing_registration import _alias_set_for_yaml
+
+ aliases = _alias_set_for_yaml(
+ provider="OPENAI",
+ model_name="my-deployment-label",
+ base_model="gpt-4o",
+ )
+ assert "gpt-4o" in aliases
+ assert "openai/gpt-4o" in aliases
+ assert "my-deployment-label" in aliases
+ assert "openai/my-deployment-label" in aliases
+
+
+def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
+ from app.services.pricing_registration import _alias_set_for_yaml
+
+ aliases = _alias_set_for_yaml(
+ provider="",
+ model_name="foo",
+ base_model="bar",
+ )
+ assert "bar" in aliases
+ assert "foo" in aliases
+ assert all("/" not in a for a in aliases)
+
+
+# ---------------------------------------------------------------------------
+# register_pricing_from_global_configs
+# ---------------------------------------------------------------------------
+
+
+class _RegistrationSpy:
+ """Captures the dicts passed to ``litellm.register_model``.
+
+ Many calls may go through; we just record them all and let tests assert
+ against the union.
+ """
+
+ def __init__(self) -> None:
+ self.calls: list[dict[str, Any]] = []
+
+ def __call__(self, payload: dict[str, Any]) -> None:
+ self.calls.append(payload)
+
+ @property
+ def all_keys(self) -> set[str]:
+ keys: set[str] = set()
+ for payload in self.calls:
+ keys.update(payload.keys())
+ return keys
+
+
+def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
+ spy = _RegistrationSpy()
+ monkeypatch.setattr(
+ "app.services.pricing_registration.litellm.register_model",
+ spy,
+ raising=False,
+ )
+ return spy
+
+
+def _patch_openrouter_pricing(
+ monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
+) -> None:
+ """Pretend the OpenRouter integration is initialised with ``mapping``."""
+
+ class _Stub:
+ def get_raw_pricing(self) -> dict[str, dict[str, str]]:
+ return mapping
+
+ class _StubService:
+ @classmethod
+ def is_initialized(cls) -> bool:
+ return True
+
+ @classmethod
+ def get_instance(cls) -> _Stub:
+ return _Stub()
+
+ monkeypatch.setattr(
+ "app.services.openrouter_integration_service.OpenRouterIntegrationService",
+ _StubService,
+ raising=False,
+ )
+
+
+def test_openrouter_models_register_under_aliases(monkeypatch):
+ """An OpenRouter config whose ``model_name`` is in the cached raw
+ pricing map is registered under both ``openrouter/X`` and bare ``X``.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(
+ monkeypatch,
+ {
+ "anthropic/claude-3-5-sonnet": {
+ "prompt": "0.000003",
+ "completion": "0.000015",
+ }
+ },
+ )
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENROUTER",
+ "model_name": "anthropic/claude-3-5-sonnet",
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys
+ assert "anthropic/claude-3-5-sonnet" in spy.all_keys
+ # Costs are float-converted from the raw OpenRouter strings.
+ payload = spy.calls[0]
+ assert payload["openrouter/anthropic/claude-3-5-sonnet"][
+ "input_cost_per_token"
+ ] == pytest.approx(3e-6)
+ assert payload["openrouter/anthropic/claude-3-5-sonnet"][
+ "output_cost_per_token"
+ ] == pytest.approx(15e-6)
+ assert (
+ payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"]
+ == "openrouter"
+ )
+
+
+def test_yaml_override_registers_under_alias_set(monkeypatch):
+ """Operator-declared ``input_cost_per_token`` /
+ ``output_cost_per_token`` on a YAML config registers under every
+ alias the YAML alias generator produces — including the ``azure/``
+ normalisation for ``azure_openai`` providers.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(monkeypatch, {})
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5.4",
+ "litellm_params": {
+ "base_model": "gpt-5.4",
+ "input_cost_per_token": 2e-6,
+ "output_cost_per_token": 8e-6,
+ },
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ keys = spy.all_keys
+ assert "gpt-5.4" in keys
+ assert "azure_openai/gpt-5.4" in keys
+ assert "azure/gpt-5.4" in keys
+
+ payload = spy.calls[0]
+ entry = payload["gpt-5.4"]
+ assert entry["input_cost_per_token"] == pytest.approx(2e-6)
+ assert entry["output_cost_per_token"] == pytest.approx(8e-6)
+ assert entry["litellm_provider"] == "azure"
+
+
+def test_no_override_means_no_registration(monkeypatch):
+ """A YAML config that *omits* both pricing fields must NOT be registered
+ — registering as zero would override LiteLLM's native pricing for the
+ ``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
+ bill drop to $0. Fail-safe is "skip and warn", not "register zero".
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(monkeypatch, {})
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENAI",
+ "model_name": "gpt-4o",
+ "litellm_params": {"base_model": "gpt-4o"},
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert spy.calls == []
+
+
+def test_openrouter_skipped_when_pricing_missing(monkeypatch):
+ """If the OpenRouter raw-pricing cache doesn't carry an entry for a
+ configured model (network blip during refresh, model added later, etc.),
+ we skip it rather than registering zero pricing.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(
+ monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
+ )
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENROUTER",
+ "model_name": "anthropic/claude-3-5-sonnet",
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert spy.calls == []
+
+
+def test_register_continues_after_individual_failure(monkeypatch, caplog):
+ """A single bad ``register_model`` call (e.g. raising LiteLLM error)
+ must not abort registration of the remaining configs.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
+ successful_calls: list[dict[str, Any]] = []
+
+ def _maybe_fail(payload: dict[str, Any]) -> None:
+ if any(k in failing_keys for k in payload):
+ raise RuntimeError("boom")
+ successful_calls.append(payload)
+
+ monkeypatch.setattr(
+ "app.services.pricing_registration.litellm.register_model",
+ _maybe_fail,
+ raising=False,
+ )
+ _patch_openrouter_pricing(
+ monkeypatch,
+ {
+ "anthropic/claude-3-5-sonnet": {
+ "prompt": "0.000003",
+ "completion": "0.000015",
+ }
+ },
+ )
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENROUTER",
+ "model_name": "anthropic/claude-3-5-sonnet",
+ },
+ {
+ "id": 2,
+ "provider": "OPENAI",
+ "model_name": "custom-deployment",
+ "litellm_params": {
+ "base_model": "custom-deployment",
+ "input_cost_per_token": 1e-6,
+ "output_cost_per_token": 2e-6,
+ },
+ },
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ # The good config still registered.
+ assert any("custom-deployment" in payload for payload in successful_calls)
+
+
+def test_vision_configs_registered_with_chat_shape(monkeypatch):
+ """``register_pricing_from_global_configs`` walks
+ ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
+ calls (during indexing) bill correctly. Vision configs use the same
+ chat-shape token prices, but image-gen pricing is intentionally NOT
+ registered here (handled via ``response_cost`` in LiteLLM).
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(
+ monkeypatch,
+ {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
+ )
+
+ # No chat configs — only vision. Proves the vision walk is a separate
+ # iteration, not piggy-backed on the chat list.
+ monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_VISION_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-4o",
+ "billing_tier": "premium",
+ "input_cost_per_token": 5e-6,
+ "output_cost_per_token": 15e-6,
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert "openrouter/openai/gpt-4o" in spy.all_keys
+ payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
+ assert payload_value["mode"] == "chat"
+ assert payload_value["litellm_provider"] == "openrouter"
+ assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
+ assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)
+
+
+def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
+ """If the OpenRouter pricing cache misses a vision model (different
+ catalogue surface), the vision walk falls back to inline
+ ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(monkeypatch, {})
+
+ monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_VISION_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-2.5-flash",
+ "billing_tier": "premium",
+ "input_cost_per_token": 1e-6,
+ "output_cost_per_token": 4e-6,
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert "openrouter/google/gemini-2.5-flash" in spy.all_keys
diff --git a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py
new file mode 100644
index 000000000..9e35b6f9c
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py
@@ -0,0 +1,157 @@
+"""Unit tests for ``QuotaCheckedVisionLLM``.
+
+Validates that:
+
+* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
+ enforcement) and forwards the inner LLM's response on success.
+* The wrapper proxies non-overridden attributes to the inner LLM
+ (``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
+ still work without quota gating (they're not used in indexing today).
+* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
+ bubbles it up — the ETL pipeline catches that and falls back to OCR.
+"""
+
+from __future__ import annotations
+
+import contextlib
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
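+
+# Shape of the wrapper exercised below (a sketch; variable names are
+# placeholders, and the production call site in the ETL pipeline may differ):
+#
+#   wrapper = QuotaCheckedVisionLLM(
+#       inner_chat_litellm,
+#       user_id=owner_id,
+#       search_space_id=space_id,
+#       billing_tier="premium",
+#       base_model="openai/gpt-4o",
+#       quota_reserve_tokens=4000,
+#   )
+#   description = await wrapper.ainvoke(messages)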
+
+
+class _FakeInnerLLM:
+ """Stand-in for ``langchain_litellm.ChatLiteLLM``."""
+
+ def __init__(self, response: Any = "OCR'd content") -> None:
+ self._response = response
+ self.ainvoke_calls: list[Any] = []
+
+ async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
+ self.ainvoke_calls.append(input)
+ return self._response
+
+ def some_other_method(self, x: int) -> int:
+ return x * 2
+
+
+@contextlib.asynccontextmanager
+async def _passthrough_billable_call(**_kwargs):
+ """Stand-in for billable_call that always allows the call to run."""
+
+ class _Acc:
+ total_cost_micros = 0
+ total_prompt_tokens = 0
+ total_completion_tokens = 0
+ grand_total = 0
+ calls: list[Any] = []
+
+ def per_message_summary(self) -> dict[str, dict[str, int]]:
+ return {}
+
+ yield _Acc()
+
+
+@pytest.mark.asyncio
+async def test_ainvoke_routes_through_billable_call(monkeypatch):
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
+
+ captured_kwargs: list[dict[str, Any]] = []
+
+ @contextlib.asynccontextmanager
+ async def _spy_billable_call(**kwargs):
+ captured_kwargs.append(kwargs)
+ async with _passthrough_billable_call() as acc:
+ yield acc
+
+ monkeypatch.setattr(
+ "app.services.quota_checked_vision_llm.billable_call",
+ _spy_billable_call,
+ raising=False,
+ )
+
+ inner = _FakeInnerLLM(response="A red apple on a white table")
+ user_id = uuid4()
+ wrapper = QuotaCheckedVisionLLM(
+ inner,
+ user_id=user_id,
+ search_space_id=99,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ )
+
+ result = await wrapper.ainvoke([{"text": "what is this?"}])
+ assert result == "A red apple on a white table"
+ assert len(inner.ainvoke_calls) == 1
+ assert len(captured_kwargs) == 1
+ bc_kwargs = captured_kwargs[0]
+ assert bc_kwargs["user_id"] == user_id
+ assert bc_kwargs["search_space_id"] == 99
+ assert bc_kwargs["billing_tier"] == "premium"
+ assert bc_kwargs["base_model"] == "openai/gpt-4o"
+ assert bc_kwargs["quota_reserve_tokens"] == 4000
+ assert bc_kwargs["usage_type"] == "vision_extraction"
+
+
+@pytest.mark.asyncio
+async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
+ from app.services.billable_calls import QuotaInsufficientError
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
+
+ @contextlib.asynccontextmanager
+ async def _denying_billable_call(**_kwargs):
+ raise QuotaInsufficientError(
+ usage_type="vision_extraction",
+ used_micros=5_000_000,
+ limit_micros=5_000_000,
+ remaining_micros=0,
+ )
+ yield # unreachable but required for asynccontextmanager type
+
+ monkeypatch.setattr(
+ "app.services.quota_checked_vision_llm.billable_call",
+ _denying_billable_call,
+ raising=False,
+ )
+
+ inner = _FakeInnerLLM()
+ wrapper = QuotaCheckedVisionLLM(
+ inner,
+ user_id=uuid4(),
+ search_space_id=1,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ )
+
+ with pytest.raises(QuotaInsufficientError):
+ await wrapper.ainvoke([{"text": "x"}])
+
+ # Inner LLM never ran on a denied reservation.
+ assert inner.ainvoke_calls == []
+
+
+@pytest.mark.asyncio
+async def test_proxies_non_overridden_attributes_to_inner():
+ """``__getattr__`` forwards anything not on the proxy itself, so any
+ method we didn't explicitly override (``invoke``, ``astream``,
+ ``with_structured_output``, etc.) still works — just without quota
+ gating, which is fine because the indexer only ever calls ainvoke.
+ """
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
+
+ inner = _FakeInnerLLM()
+ wrapper = QuotaCheckedVisionLLM(
+ inner,
+ user_id=uuid4(),
+ search_space_id=1,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ )
+
+ # ``some_other_method`` is on the inner only.
+ assert wrapper.some_other_method(7) == 14
diff --git a/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py
new file mode 100644
index 000000000..63681828d
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py
@@ -0,0 +1,515 @@
+"""Cost-based premium quota unit tests.
+
+Covers the USD-micro behaviour added in migration 140:
+
+* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
+ calls in a turn — used as the debit amount when ``agent_config.is_premium``
+ is true, regardless of which underlying model produced each call. This
+ preserves the prior "premium turn → all calls in turn count" rule from the
+ token-based system.
+* ``estimate_call_reserve_micros`` scales linearly with model pricing,
+ clamps to a sane floor when pricing is unknown, and respects the
+ ``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
+ can't lock the whole balance on one call.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# TurnTokenAccumulator — premium-turn debit semantics
+# ---------------------------------------------------------------------------
+
+
+def test_total_cost_micros_sums_premium_and_free_calls():
+ """A premium turn that also called a free sub-agent debits the union.
+
+ The plan deliberately preserved the existing "premium turn → all calls
+ count" behaviour because per-call premium filtering relied on
+ ``LLMRouterService._premium_model_strings`` which only covers router-pool
+ deployments. ``total_cost_micros`` therefore must include free-model
+ calls (whose ``cost_micros`` is typically ``0``) as well as the premium
+ call's actual provider cost.
+ """
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ # Premium model (e.g. claude-opus): non-zero cost.
+ acc.add(
+ model="anthropic/claude-3-5-sonnet",
+ prompt_tokens=1200,
+ completion_tokens=400,
+ total_tokens=1600,
+ cost_micros=12_345,
+ )
+ # Free sub-agent (e.g. title-gen on a free model): zero cost.
+ acc.add(
+ model="gpt-4o-mini",
+ prompt_tokens=120,
+ completion_tokens=20,
+ total_tokens=140,
+ cost_micros=0,
+ )
+ # A second premium-priced call within the same turn.
+ acc.add(
+ model="anthropic/claude-3-5-sonnet",
+ prompt_tokens=800,
+ completion_tokens=200,
+ total_tokens=1000,
+ cost_micros=7_500,
+ )
+
+ assert acc.total_cost_micros == 12_345 + 0 + 7_500
+ # Token totals stay correct so the FE display path still works.
+ assert acc.grand_total == 1600 + 140 + 1000
+
+
+def test_total_cost_micros_zero_when_no_calls():
+ """An empty accumulator must report zero cost (no division-by-zero, no None)."""
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ assert acc.total_cost_micros == 0
+ assert acc.grand_total == 0
+
+
+def test_per_message_summary_groups_cost_by_model():
+ """``per_message_summary`` must accumulate ``cost_micros`` per model so the
+ SSE ``model_breakdown`` payload reports actual USD spend per provider.
+ """
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ acc.add(
+ model="claude-3-5-sonnet",
+ prompt_tokens=100,
+ completion_tokens=50,
+ total_tokens=150,
+ cost_micros=4_000,
+ )
+ acc.add(
+ model="claude-3-5-sonnet",
+ prompt_tokens=200,
+ completion_tokens=100,
+ total_tokens=300,
+ cost_micros=8_000,
+ )
+ acc.add(
+ model="gpt-4o-mini",
+ prompt_tokens=50,
+ completion_tokens=10,
+ total_tokens=60,
+ cost_micros=200,
+ )
+
+ summary = acc.per_message_summary()
+ assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000
+ assert summary["claude-3-5-sonnet"]["total_tokens"] == 450
+ assert summary["gpt-4o-mini"]["cost_micros"] == 200
+
+
+def test_serialized_calls_includes_cost_micros():
+ """``serialized_calls`` is what flows into the SSE ``call_details``
+ payload; cost_micros must be present on each entry so the FE message-info
+ dropdown can render per-call USD.
+ """
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ acc.add(
+ model="m",
+ prompt_tokens=1,
+ completion_tokens=1,
+ total_tokens=2,
+ cost_micros=42,
+ )
+ serialized = acc.serialized_calls()
+ assert serialized == [
+ {
+ "model": "m",
+ "prompt_tokens": 1,
+ "completion_tokens": 1,
+ "total_tokens": 2,
+ "cost_micros": 42,
+ "call_kind": "chat",
+ }
+ ]
+
+
+# ---------------------------------------------------------------------------
+# estimate_call_reserve_micros — sizing and clamping
+# ---------------------------------------------------------------------------
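+
+# Back-of-envelope for the sizing tests below (1 USD == 1_000_000 micros):
+#
+#   reserve_usd    = reserve_tokens * (input_cost_per_token + output_cost_per_token)
+#   reserve_micros = reserve_usd * 1_000_000, capped at config.QUOTA_MAX_RESERVE_MICROS;
+#                    unknown or zero pricing falls back to _QUOTA_MIN_RESERVE_MICROS (100).
+#
+#   e.g. 4000 tokens on a 15e-6/75e-6 model: 4000 * 90e-6 USD = $0.36 -> 360_000 micros.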
+
+
+def test_reserve_returns_floor_when_model_unknown(monkeypatch):
+ """If LiteLLM doesn't know the model, ``get_model_info`` raises and the
+ helper falls back to the 100-micro floor — small enough that a user with
+ $0.0001 left can still send a tiny request, but non-zero so we still gate
+ against an empty balance.
+ """
+ import litellm
+
+ from app.services import token_quota_service
+
+ def _raise(_name):
+ raise KeyError("unknown")
+
+ monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False)
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="nonexistent-model",
+ quota_reserve_tokens=4000,
+ )
+ assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
+ assert micros == 100
+
+
+def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
+ """LiteLLM may *return* a model with both cost-per-token fields at 0
+ (pricing not yet registered). The helper must not multiply 0 x tokens
+ and end up reserving 0 — it must clamp to the floor.
+ """
+ import litellm
+
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0},
+ raising=False,
+ )
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="some-pending-model",
+ quota_reserve_tokens=4000,
+ )
+ assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
+
+
+def test_reserve_scales_with_model_cost(monkeypatch):
+ """Claude-Opus-priced model with 4000 reserve_tokens reserves
+ ~$0.36 = 360_000 micros. Critically this must NOT be clamped down to
+ some small artificial cap — that was the bug the plan called out.
+ """
+ import litellm
+
+ from app.config import config
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {
+ "input_cost_per_token": 15e-6,
+ "output_cost_per_token": 75e-6,
+ },
+ raising=False,
+ )
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="claude-3-opus",
+ quota_reserve_tokens=4000,
+ )
+ # 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros.
+ assert micros == 360_000
+
+
+def test_reserve_clamps_to_max_ceiling(monkeypatch):
+ """A misconfigured "$1000 / M" model with 4000 reserve_tokens would
+ nominally compute to $4 = 4_000_000 micros. The ceiling
+ ``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry
+ can't lock the user's whole balance on one call.
+ """
+ import litellm
+
+ from app.config import config
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {
+ "input_cost_per_token": 1e-3,
+ "output_cost_per_token": 0,
+ },
+ raising=False,
+ )
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="oops-misconfigured",
+ quota_reserve_tokens=4000,
+ )
+ assert micros == 1_000_000
+
+
+def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
+ """Per-config ``quota_reserve_tokens`` is optional; when ``None`` or
+ zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL``
+ so anonymous-style configs still reserve the operator-tunable default.
+ """
+ import litellm
+
+ from app.config import config
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {
+ "input_cost_per_token": 1e-6,
+ "output_cost_per_token": 1e-6,
+ },
+ raising=False,
+ )
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
+
+ # 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros
+ assert (
+ token_quota_service.estimate_call_reserve_micros(
+ base_model="cheap", quota_reserve_tokens=None
+ )
+ == 4000
+ )
+ assert (
+ token_quota_service.estimate_call_reserve_micros(
+ base_model="cheap", quota_reserve_tokens=0
+ )
+ == 4000
+ )
+
+
+# ---------------------------------------------------------------------------
+# TokenTrackingCallback — image vs chat usage shape
+# ---------------------------------------------------------------------------
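+
+# Cost conversion assumed throughout this section: ``response_cost`` is USD,
+# while the accumulator stores micros, so 0.04 USD -> 40_000 and
+# 0.0036 USD -> 3_600.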
+
+
+class _FakeImageUsage:
+ """Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""
+
+ def __init__(
+ self,
+ input_tokens: int = 0,
+ output_tokens: int = 0,
+ total_tokens: int | None = None,
+ ) -> None:
+ self.input_tokens = input_tokens
+ self.output_tokens = output_tokens
+ if total_tokens is not None:
+ self.total_tokens = total_tokens
+
+
+class _FakeImageResponse:
+ """Mimics LiteLLM's ``ImageResponse`` — same name so the callback's
+ ``type(...).__name__`` probe routes to the image branch.
+ """
+
+ def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
+ self.usage = usage
+ if response_cost is not None:
+ self._hidden_params = {"response_cost": response_cost}
+
+
+# Re-tag the helper class as ``ImageResponse`` for the callback's type-name
+# probe. We deliberately don't name the class ``ImageResponse`` outright:
+# test modules can be imported in surprising ways, and the explicit re-tag
+# keeps the intent obvious.
+_FakeImageResponse.__name__ = "ImageResponse"
+
+
+class _FakeChatUsage:
+ def __init__(self, prompt: int, completion: int):
+ self.prompt_tokens = prompt
+ self.completion_tokens = completion
+ self.total_tokens = prompt + completion
+
+
+class _FakeChatResponse:
+ def __init__(self, usage: _FakeChatUsage):
+ self.usage = usage
+
+
+@pytest.mark.asyncio
+async def test_callback_reads_image_usage_input_output_tokens():
+ """``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
+ for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
+ prompt_tokens/completion_tokens which is the chat shape.
+ """
+ from app.services.token_tracking_service import (
+ TokenTrackingCallback,
+ scoped_turn,
+ )
+
+ cb = TokenTrackingCallback()
+ response = _FakeImageResponse(
+ usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
+ response_cost=0.04, # $0.04 per image
+ )
+
+ async with scoped_turn() as acc:
+ await cb.async_log_success_event(
+ kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
+ response_obj=response,
+ start_time=None,
+ end_time=None,
+ )
+ assert len(acc.calls) == 1
+ call = acc.calls[0]
+ assert call.prompt_tokens == 42
+ assert call.completion_tokens == 8
+ assert call.total_tokens == 50
+ # 0.04 USD = 40_000 micros
+ assert call.cost_micros == 40_000
+ assert call.call_kind == "image_generation"
+
+
+@pytest.mark.asyncio
+async def test_callback_chat_path_unchanged():
+ """Chat responses must still read prompt_tokens/completion_tokens."""
+ from app.services.token_tracking_service import (
+ TokenTrackingCallback,
+ scoped_turn,
+ )
+
+ cb = TokenTrackingCallback()
+ response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))
+
+ async with scoped_turn() as acc:
+ await cb.async_log_success_event(
+ kwargs={
+ "model": "openrouter/anthropic/claude-3-5-sonnet",
+ "response_cost": 0.0036,
+ },
+ response_obj=response,
+ start_time=None,
+ end_time=None,
+ )
+ assert len(acc.calls) == 1
+ call = acc.calls[0]
+ assert call.prompt_tokens == 120
+ assert call.completion_tokens == 30
+ assert call.total_tokens == 150
+ assert call.cost_micros == 3_600
+ assert call.call_kind == "chat"
+
+
+@pytest.mark.asyncio
+async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
+ """When OpenRouter omits ``usage.cost`` LiteLLM's
+ ``default_image_cost_calculator`` raises. The defensive image branch in
+ ``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is
+ chat-shaped and would raise too) — it returns 0 with a WARNING log.
+ """
+ import litellm
+
+ from app.services.token_tracking_service import (
+ TokenTrackingCallback,
+ scoped_turn,
+ )
+
+ # Force completion_cost to raise the same way OpenRouter image-gen fails.
+ def _boom(*_args, **_kwargs):
+ raise ValueError("model_cost: missing entry for openrouter image model")
+
+ monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False)
+
+ # And make sure cost_per_token is NEVER called for the image path —
+ # if it were, our ``is_image=True`` branch is broken.
+ cost_per_token_calls: list = []
+
+ def _record_cost_per_token(**kwargs):
+ cost_per_token_calls.append(kwargs)
+ return (0.0, 0.0)
+
+ monkeypatch.setattr(
+ litellm, "cost_per_token", _record_cost_per_token, raising=False
+ )
+
+ cb = TokenTrackingCallback()
+ response = _FakeImageResponse(
+ usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
+ )
+
+ async with scoped_turn() as acc:
+ await cb.async_log_success_event(
+ kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
+ response_obj=response,
+ start_time=None,
+ end_time=None,
+ )
+
+ assert len(acc.calls) == 1
+ assert acc.calls[0].cost_micros == 0
+ assert acc.calls[0].call_kind == "image_generation"
+ # The image branch must short-circuit before cost_per_token.
+ assert cost_per_token_calls == []
+
+
+# ---------------------------------------------------------------------------
+# scoped_turn — ContextVar reset semantics (issue B)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_scoped_turn_restores_outer_accumulator():
+ """``scoped_turn`` must restore the previous ContextVar value on exit
+ so a per-call wrapper inside an outer chat turn doesn't leak its
+ accumulator outward (which would cause double-debit at chat-turn exit).
+ """
+ from app.services.token_tracking_service import (
+ get_current_accumulator,
+ scoped_turn,
+ start_turn,
+ )
+
+ outer = start_turn()
+ assert get_current_accumulator() is outer
+
+ async with scoped_turn() as inner:
+ assert get_current_accumulator() is inner
+ assert inner is not outer
+ inner.add(
+ model="x",
+ prompt_tokens=1,
+ completion_tokens=1,
+ total_tokens=2,
+ cost_micros=5,
+ )
+
+ # After exit the outer accumulator is restored unchanged.
+ assert get_current_accumulator() is outer
+ assert outer.total_cost_micros == 0
+ assert len(outer.calls) == 0
+ # The inner accumulator captured the call but didn't bleed into outer.
+ assert inner.total_cost_micros == 5
+
+
+@pytest.mark.asyncio
+async def test_scoped_turn_resets_to_none_when_no_outer():
+ """Running ``scoped_turn`` outside any chat turn (e.g. a background
+ indexing job) must leave the ContextVar at ``None`` on exit so the
+ next *unrelated* request starts clean.
+ """
+ from app.services.token_tracking_service import (
+ _turn_accumulator,
+ get_current_accumulator,
+ scoped_turn,
+ )
+
+    # The ContextVar default is None in a fresh, isolated test context. We
+    # simulate "no outer" explicitly to be robust against test ordering.
+ token = _turn_accumulator.set(None)
+ try:
+ assert get_current_accumulator() is None
+ async with scoped_turn() as acc:
+ assert get_current_accumulator() is acc
+ assert get_current_accumulator() is None
+ finally:
+ _turn_accumulator.reset(token)
diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py
new file mode 100644
index 000000000..38d6ba2ca
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py
@@ -0,0 +1,325 @@
+"""Unit tests for podcast Celery task billing integration.
+
+Validates ``_generate_content_podcast`` correctly wraps
+``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the
+search-space owner's billing decision, and degrades cleanly when the
+resolver fails or premium credit is exhausted.
+
+Coverage:
+
+* Happy-path free config: resolver → ``billable_call`` enters with
+ ``usage_type='podcast_generation'`` and the configured reserve override,
+ graph runs, podcast row flips to ``READY``.
+* Happy-path premium config: same wiring with ``billing_tier='premium'``.
+* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` →
+ graph is *not* invoked, podcast row flips to ``FAILED``, return dict
+ carries ``reason='premium_quota_exhausted'``.
+* Resolver failure: ``ValueError`` from the resolver → podcast row flips
+ to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``.
+"""
+
+from __future__ import annotations
+
+import contextlib
+from types import SimpleNamespace
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeExecResult:
+ def __init__(self, obj):
+ self._obj = obj
+
+ def scalars(self):
+ return self
+
+ def first(self):
+ return self._obj
+
+ def filter(self, *_args, **_kwargs):
+ return self
+
+
+class _FakeSession:
+ def __init__(self, podcast):
+ self._podcast = podcast
+ self.commit_count = 0
+
+ async def execute(self, _stmt):
+ return _FakeExecResult(self._podcast)
+
+ async def commit(self):
+ self.commit_count += 1
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, *args):
+ return None
+
+
+class _FakeSessionMaker:
+ def __init__(self, session: _FakeSession):
+ self._session = session
+
+ def __call__(self):
+ return self._session
+
+
+def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace:
+ """Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily
+ inside helpers keeps this fixture cheap."""
+ return SimpleNamespace(
+ id=podcast_id,
+ title="Test Podcast",
+ thread_id=thread_id,
+ status=None,
+ podcast_transcript=None,
+ file_location=None,
+ )
+
+
+@contextlib.asynccontextmanager
+async def _ok_billable_call(**kwargs):
+ """Stand-in for ``billable_call`` that records its kwargs and yields a
+ no-op accumulator-shaped object."""
+ _CALL_LOG.append(kwargs)
+ yield SimpleNamespace()
+
+
+_CALL_LOG: list[dict[str, Any]] = []
+
+
+@contextlib.asynccontextmanager
+async def _denying_billable_call(**kwargs):
+ from app.services.billable_calls import QuotaInsufficientError
+
+ _CALL_LOG.append(kwargs)
+ raise QuotaInsufficientError(
+ usage_type=kwargs.get("usage_type", "?"),
+ used_micros=5_000_000,
+ limit_micros=5_000_000,
+ remaining_micros=0,
+ )
+    yield SimpleNamespace()  # pragma: no cover; unreachable, keeps this a generator
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def _reset_call_log():
+ _CALL_LOG.clear()
+ yield
+ _CALL_LOG.clear()
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
+ """Happy path: free billing tier still wraps the graph call so the
+ audit row is recorded. Verifies kwargs threading."""
+ from app.config import config as app_config
+ from app.db import PodcastStatus
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast(podcast_id=7, thread_id=99)
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ assert search_space_id == 555
+ assert thread_id == 99
+ return user_id, "free", "openrouter/some-free-model"
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
+ )
+ monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {
+ "podcast_transcript": [
+ SimpleNamespace(speaker_id=0, dialog="Hi"),
+ SimpleNamespace(speaker_id=1, dialog="Hello"),
+ ],
+ "final_podcast_file_path": "/tmp/podcast.wav",
+ }
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ result = await podcast_tasks._generate_content_podcast(
+ podcast_id=7,
+ source_content="hello world",
+ search_space_id=555,
+ user_prompt="make it short",
+ )
+
+ assert result["status"] == "ready"
+ assert result["podcast_id"] == 7
+ assert podcast.status == PodcastStatus.READY
+ assert podcast.file_location == "/tmp/podcast.wav"
+
+ assert len(_CALL_LOG) == 1
+ call = _CALL_LOG[0]
+ assert call["user_id"] == user_id
+ assert call["search_space_id"] == 555
+ assert call["billing_tier"] == "free"
+ assert call["base_model"] == "openrouter/some-free-model"
+ assert call["usage_type"] == "podcast_generation"
+ assert (
+ call["quota_reserve_micros_override"]
+ == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
+ )
+ assert call["thread_id"] == 99
+ assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_premium_tier(monkeypatch):
+ """Premium resolution flows through to ``billable_call`` so the
+ reserve/finalize path triggers."""
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast()
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return user_id, "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
+ )
+ monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ await podcast_tasks._generate_content_podcast(
+ podcast_id=7,
+ source_content="hi",
+ search_space_id=555,
+ user_prompt=None,
+ )
+
+ assert _CALL_LOG[0]["billing_tier"] == "premium"
+ assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch):
+ """When ``billable_call`` denies the reservation, the graph never
+ runs and the podcast row flips to FAILED with the documented reason
+ code."""
+ from app.db import PodcastStatus
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast(podcast_id=8)
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return uuid4(), "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
+ )
+ monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call)
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ result = await podcast_tasks._generate_content_podcast(
+ podcast_id=8,
+ source_content="hi",
+ search_space_id=555,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "podcast_id": 8,
+ "reason": "premium_quota_exhausted",
+ }
+ assert podcast.status == PodcastStatus.FAILED
+ assert graph_invoked == [] # Graph never ran on denied reservation.
+
+
+@pytest.mark.asyncio
+async def test_resolver_failure_marks_podcast_failed(monkeypatch):
+ """If the resolver raises (e.g. search-space deleted), the task fails
+ cleanly without invoking the graph."""
+ from app.db import PodcastStatus
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast(podcast_id=9)
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _failing_resolver(sess, search_space_id, *, thread_id=None):
+ raise ValueError("Search space 555 not found")
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver
+ )
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ result = await podcast_tasks._generate_content_podcast(
+ podcast_id=9,
+ source_content="hi",
+ search_space_id=555,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "podcast_id": 9,
+ "reason": "billing_resolution_failed",
+ }
+ assert podcast.status == PodcastStatus.FAILED
+ assert graph_invoked == []
diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py
new file mode 100644
index 000000000..671f57ae4
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py
@@ -0,0 +1,330 @@
+"""Unit tests for video-presentation Celery task billing integration.
+
+Mirrors ``test_podcast_billing.py`` for the video-presentation task.
+Validates the same wrap-graph-in-billable_call pattern and ensures the
+larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is
+threaded through.
+
+Coverage:
+
+* Free config: graph runs, ``billable_call`` invoked with the video
+ reserve override.
+* Premium config: same wiring with ``billing_tier='premium'``.
+* Quota denial: graph not invoked, row → FAILED, reason code surfaced.
+* Resolver failure: row → FAILED with ``billing_resolution_failed``.
+"""
+
+from __future__ import annotations
+
+import contextlib
+from types import SimpleNamespace
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeExecResult:
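+ """Minimal stand-in for a SQLAlchemy result: supports the
+ ``.scalars().first()`` and ``.filter()`` chaining the task code expects."""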
+ def __init__(self, obj):
+ self._obj = obj
+
+ def scalars(self):
+ return self
+
+ def first(self):
+ return self._obj
+
+ def filter(self, *_args, **_kwargs):
+ return self
+
+
+class _FakeSession:
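+ """Async-session stub: ``execute`` always returns the seeded video row
+ and ``commit`` only counts calls."""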
+ def __init__(self, video):
+ self._video = video
+ self.commit_count = 0
+
+ async def execute(self, _stmt):
+ return _FakeExecResult(self._video)
+
+ async def commit(self):
+ self.commit_count += 1
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, *args):
+ return None
+
+
+class _FakeSessionMaker:
+ def __init__(self, session: _FakeSession):
+ self._session = session
+
+ def __call__(self):
+ return self._session
+
+
+def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace:
+ return SimpleNamespace(
+ id=video_id,
+ title="Test Presentation",
+ thread_id=thread_id,
+ status=None,
+ slides=None,
+ scene_codes=None,
+ )
+
+
+_CALL_LOG: list[dict[str, Any]] = []
+
+
+@contextlib.asynccontextmanager
+async def _ok_billable_call(**kwargs):
+ _CALL_LOG.append(kwargs)
+ yield SimpleNamespace()
+
+
+@contextlib.asynccontextmanager
+async def _denying_billable_call(**kwargs):
+ from app.services.billable_calls import QuotaInsufficientError
+
+ _CALL_LOG.append(kwargs)
+ raise QuotaInsufficientError(
+ usage_type=kwargs.get("usage_type", "?"),
+ used_micros=5_000_000,
+ limit_micros=5_000_000,
+ remaining_micros=0,
+ )
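+ # Unreachable: the ``yield`` below exists only so ``asynccontextmanager``
+ # sees an async generator function.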
+ yield SimpleNamespace() # pragma: no cover
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def _reset_call_log():
+ _CALL_LOG.clear()
+ yield
+ _CALL_LOG.clear()
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
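+ """Free config: the graph runs and ``billable_call`` is invoked with the
+ video reserve override and thread id."""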
+ from app.config import config as app_config
+ from app.db import VideoPresentationStatus
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video(video_id=11, thread_id=99)
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ assert search_space_id == 777
+ assert thread_id == 99
+ return user_id, "free", "openrouter/some-free-model"
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _fake_resolver,
+ )
+ monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ result = await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=11,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert result["status"] == "ready"
+ assert result["video_presentation_id"] == 11
+ assert video.status == VideoPresentationStatus.READY
+
+ assert len(_CALL_LOG) == 1
+ call = _CALL_LOG[0]
+ assert call["user_id"] == user_id
+ assert call["search_space_id"] == 777
+ assert call["billing_tier"] == "free"
+ assert call["base_model"] == "openrouter/some-free-model"
+ assert call["usage_type"] == "video_presentation_generation"
+ assert (
+ call["quota_reserve_micros_override"]
+ == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
+ )
+ assert call["thread_id"] == 99
+ assert call["call_details"] == {
+ "video_presentation_id": 11,
+ "title": "Test Presentation",
+ }
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_premium_tier(monkeypatch):
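+ """Premium config: same wiring, with ``billing_tier='premium'`` and the
+ premium base model flowing through to ``billable_call``."""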
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video()
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return user_id, "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _fake_resolver,
+ )
+ monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=11,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert _CALL_LOG[0]["billing_tier"] == "premium"
+ assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch):
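+ """When ``billable_call`` denies the reservation, the graph never runs
+ and the video row flips to FAILED with the documented reason code."""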
+ from app.db import VideoPresentationStatus
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video(video_id=12)
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return uuid4(), "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _fake_resolver,
+ )
+ monkeypatch.setattr(
+ video_presentation_tasks, "billable_call", _denying_billable_call
+ )
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ result = await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=12,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "video_presentation_id": 12,
+ "reason": "premium_quota_exhausted",
+ }
+ assert video.status == VideoPresentationStatus.FAILED
+ assert graph_invoked == []
+
+
+@pytest.mark.asyncio
+async def test_resolver_failure_marks_video_failed(monkeypatch):
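+ """If the resolver raises (e.g. search-space deleted), the task fails
+ cleanly without invoking the graph."""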
+ from app.db import VideoPresentationStatus
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video(video_id=13)
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _failing_resolver(sess, search_space_id, *, thread_id=None):
+ raise ValueError("Search space 777 not found")
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _failing_resolver,
+ )
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ result = await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=13,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "video_presentation_id": 13,
+ "reason": "billing_resolution_failed",
+ }
+ assert video.status == VideoPresentationStatus.FAILED
+ assert graph_invoked == []
diff --git a/surfsense_web/app/(home)/free/page.tsx b/surfsense_web/app/(home)/free/page.tsx
index 8d9ed5cb1..3ddd5195f 100644
--- a/surfsense_web/app/(home)/free/page.tsx
+++ b/surfsense_web/app/(home)/free/page.tsx
@@ -127,7 +127,7 @@ const FAQ_ITEMS = [
{
question: "What happens after I use my free tokens?",
answer:
- "After your free tokens, create a free SurfSense account to unlock 3 million more premium tokens. Additional tokens can be purchased at $1 per million. Non-premium models remain unlimited for registered users.",
+ "After your free tokens, create a free SurfSense account to unlock $5 of premium credit. Additional credit can be topped up at $1 for $1 of credit, billed at the actual provider cost. Non-premium models remain unlimited for registered users.",
},
{
question: "Is Claude AI available without login?",
@@ -329,7 +329,7 @@ export default async function FreeHubPage() {
Want More Features?
- Create a free SurfSense account to unlock 3 million tokens, document uploads with
+ Create a free SurfSense account to unlock $5 of premium credit, document uploads with
citations, team collaboration, and integrations with Slack, Google Drive, Notion, and
30+ more tools.
diff --git a/surfsense_web/app/(home)/pricing/page.tsx b/surfsense_web/app/(home)/pricing/page.tsx
index 6ad9435bf..6f332be70 100644
--- a/surfsense_web/app/(home)/pricing/page.tsx
+++ b/surfsense_web/app/(home)/pricing/page.tsx
@@ -5,7 +5,7 @@ import { BreadcrumbNav } from "@/components/seo/breadcrumb-nav";
export const metadata: Metadata = {
title: "Pricing | SurfSense - Free AI Search Plans",
description:
- "Explore SurfSense plans and pricing. Start free with 500 pages & 3M premium tokens. Use ChatGPT, Claude AI, and premium AI models. Pay-as-you-go tokens at $1 per million.",
+ "Explore SurfSense plans and pricing. Start free with 500 pages & $5 of premium credit. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost — $1 buys $1 of credit.",
alternates: {
canonical: "https://surfsense.com/pricing",
},
diff --git a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx
index 3017160e1..0c5662712 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx
@@ -8,7 +8,7 @@ import { cn } from "@/lib/utils";
const TABS = [
{ id: "pages", label: "Pages" },
- { id: "tokens", label: "Premium Tokens" },
+ { id: "tokens", label: "Premium Credit" },
] as const;
type TabId = (typeof TABS)[number]["id"];
diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx
index 2b7422f80..cf73b5eba 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx
@@ -28,6 +28,12 @@ type UnifiedPurchase = {
kind: PurchaseKind;
created_at: string;
status: PagePurchaseStatus;
+ /**
+ * Granted units. Interpretation depends on ``kind``:
+ * - ``"pages"`` — integer number of indexed pages.
+ * - ``"tokens"`` — integer micro-USD of credit (1_000_000 = $1.00).
+ * The ``Granted`` column formats accordingly.
+ */
granted: number;
amount_total: number | null;
currency: string | null;
@@ -58,7 +64,7 @@ const KIND_META: Record<
iconClass: "text-sky-500",
},
tokens: {
- label: "Premium Tokens",
+ label: "Premium Credit",
icon: Coins,
iconClass: "text-amber-500",
},
@@ -97,12 +103,25 @@ function normalizeTokenPurchase(p: TokenPurchase): UnifiedPurchase {
kind: "tokens",
created_at: p.created_at,
status: p.status,
- granted: p.tokens_granted,
+ granted: p.credit_micros_granted,
amount_total: p.amount_total,
currency: p.currency,
};
}
+function formatGranted(p: UnifiedPurchase): string {
+ if (p.kind === "tokens") {
+ const dollars = p.granted / 1_000_000;
+ // Premium credit packs are always whole dollars at the moment, but
+ // future fractional grants (refunds, partial top-ups) shouldn't
+ // silently round to "$0".
+ if (dollars >= 1) return `$${dollars.toFixed(2)} of credit`;
+ if (dollars > 0) return `$${dollars.toFixed(3)} of credit`;
+ return "$0 of credit";
+ }
+ return p.granted.toLocaleString();
+}
+
export function PurchaseHistoryContent() {
const results = useQueries({
queries: [
@@ -143,7 +162,7 @@ export function PurchaseHistoryContent() {
No purchases yet
- Your page and premium token purchases will appear here after checkout.
+ Your page and premium credit purchases will appear here after checkout.
);
@@ -177,7 +196,7 @@ export function PurchaseHistoryContent() {
- {p.granted.toLocaleString()}
+ {formatGranted(p)}
{formatAmount(p.amount_total, p.currency)}
diff --git a/surfsense_web/atoms/user/user-query.atoms.ts b/surfsense_web/atoms/user/user-query.atoms.ts
index a59811324..4b6717440 100644
--- a/surfsense_web/atoms/user/user-query.atoms.ts
+++ b/surfsense_web/atoms/user/user-query.atoms.ts
@@ -8,9 +8,9 @@ const userQueryFn = () => userApiService.getMe();
export const currentUserAtom = atomWithQuery(() => {
return {
queryKey: USER_QUERY_KEY,
- // Live-changing numeric fields (pages_*, premium_tokens_*) are now
- // pushed via Zero (queries.user.me()), so /users/me only needs to
- // fire once per session for the static profile fields.
+ // Live-changing numeric fields (pages_*, premium_credit_micros_*)
+ // are now pushed via Zero (queries.user.me()), so /users/me only
+ // needs to fire once per session for the static profile fields.
staleTime: Infinity,
enabled: !!getBearerToken(),
queryFn: userQueryFn,
diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx
index 711bb2fe2..ffb0e4dc8 100644
--- a/surfsense_web/components/assistant-ui/assistant-message.tsx
+++ b/surfsense_web/components/assistant-ui/assistant-message.tsx
@@ -399,6 +399,19 @@ function formatMessageDate(date: Date): string {
});
}
+/**
+ * Format provider USD cost (in micro-USD) for inline display next to a
+ * token count. Falls back to ``"<$0.001"`` for sub-tenth-of-a-cent
+ * costs so a real-but-tiny figure doesn't render as ``$0.000``.
+ */
+function formatTurnCost(micros: number): string {
+ const dollars = micros / 1_000_000;
+ if (dollars >= 1) return `$${dollars.toFixed(2)}`;
+ if (dollars >= 0.001) return `$${dollars.toFixed(3)}`;
+ if (dollars > 0) return "<$0.001";
+ return "$0";
+}
+
const MessageInfoDropdown: FC = () => {
const messageId = useAuiState(({ message }) => message?.id);
const createdAt = useAuiState(({ message }) => message?.createdAt);
@@ -451,6 +464,7 @@ const MessageInfoDropdown: FC = () => {
{models.length > 0 ? (
models.map(([model, counts]) => {
const { name, icon } = resolveModel(model);
+ const costMicros = counts.cost_micros;
return (
{
{counts.total_tokens.toLocaleString()} tokens
+ {costMicros && costMicros > 0
+ ? ` · ${formatTurnCost(costMicros)}`
+ : ""}
);
@@ -474,6 +491,9 @@ const MessageInfoDropdown: FC = () => {
>
{usage.total_tokens.toLocaleString()} tokens
+ {usage.cost_micros && usage.cost_micros > 0
+ ? ` · ${formatTurnCost(usage.cost_micros)}`
+ : ""}
)}
diff --git a/surfsense_web/components/assistant-ui/token-usage-context.tsx b/surfsense_web/components/assistant-ui/token-usage-context.tsx
index b3f71ab21..dd80bcac3 100644
--- a/surfsense_web/components/assistant-ui/token-usage-context.tsx
+++ b/surfsense_web/components/assistant-ui/token-usage-context.tsx
@@ -13,13 +13,30 @@ export interface TokenUsageData {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
+ /**
+ * Total provider USD cost for this assistant turn, in micro-USD
+ * (1_000_000 = $1.00). Populated from LiteLLM's response_cost on
+ * the backend. Optional because pre-cost-credits messages persisted
+ * before the migration won't have it.
+ */
+ cost_micros?: number;
usage?: Record<
string,
- { prompt_tokens: number; completion_tokens: number; total_tokens: number }
+ {
+ prompt_tokens: number;
+ completion_tokens: number;
+ total_tokens: number;
+ cost_micros?: number;
+ }
>;
model_breakdown?: Record<
string,
- { prompt_tokens: number; completion_tokens: number; total_tokens: number }
+ {
+ prompt_tokens: number;
+ completion_tokens: number;
+ total_tokens: number;
+ cost_micros?: number;
+ }
>;
}
diff --git a/surfsense_web/components/free-chat/quota-warning-banner.tsx b/surfsense_web/components/free-chat/quota-warning-banner.tsx
index 3bfedf1b3..e013a64a8 100644
--- a/surfsense_web/components/free-chat/quota-warning-banner.tsx
+++ b/surfsense_web/components/free-chat/quota-warning-banner.tsx
@@ -40,7 +40,7 @@ export function QuotaWarningBanner({
You've used all {limit.toLocaleString()} free tokens. Create a free account to
- get 3 million tokens and access to all models.
+ get $5 of premium credit and access to all models.
Create an account
{" "}
- for 5M free tokens.
+ for $5 of premium credit.
{
- if (n >= 1_000_000) return `${(n / 1_000_000).toFixed(1)}M`;
- if (n >= 1_000) return `${(n / 1_000).toFixed(0)}K`;
- return n.toLocaleString();
+ const formatUsd = (micros: number) => {
+ const dollars = micros / 1_000_000;
+ if (dollars >= 100) return `$${dollars.toFixed(0)}`;
+ if (dollars >= 1) return `$${dollars.toFixed(2)}`;
+ // Sub-dollar balances need extra precision so the bar still tells the
+ // user what's left ("$0.04 of credit") instead of rounding to "$0".
+ if (dollars > 0) return `$${dollars.toFixed(3)}`;
+ return "$0";
};
return (
- {formatTokens(me.premiumTokensUsed)} / {formatTokens(me.premiumTokensLimit)} tokens
+ {formatUsd(me.premiumCreditMicrosUsed)} / {formatUsd(me.premiumCreditMicrosLimit)} of
+ credit
{usagePercentage.toFixed(0)}%
diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx
index 175cae4ab..127b79167 100644
--- a/surfsense_web/components/pricing/pricing-section.tsx
+++ b/surfsense_web/components/pricing/pricing-section.tsx
@@ -12,11 +12,11 @@ const demoPlans = [
price: "0",
yearlyPrice: "0",
period: "",
- billingText: "500 pages + 3M premium tokens included",
+ billingText: "500 pages + $5 of premium credit included",
features: [
"Self Hostable",
"500 pages included to start",
- "3 million premium tokens to start",
+ "$5 of premium credit to start, billed at provider cost",
"Includes access to OpenAI text, audio and image models",
"Realtime Collaborative Group Chats with teammates",
"Community support on Discord",
@@ -35,7 +35,7 @@ const demoPlans = [
features: [
"Everything in Free",
"Buy 1,000-page packs at $1 each",
- "Buy 1M premium token packs at $1 each",
+ "Top up premium credit at $1 per $1 of credit, billed at provider cost",
"Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter",
"Priority support on Discord",
],
@@ -129,27 +129,27 @@ const faqData: FAQSection[] = [
],
},
{
- title: "Premium Tokens",
+ title: "Premium Credit",
items: [
{
- question: 'What are "premium tokens"?',
+ question: 'What is "premium credit"?',
answer:
- "Premium tokens are the billing unit for using premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro in SurfSense. Each AI request consumes tokens based on the length of your conversation. Non-premium models (such as free-tier models available without login) do not consume premium tokens.",
+ "Premium credit is your USD balance for using premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro in SurfSense. Each AI request debits the actual USD cost the provider charges, so cheap and expensive models bill proportionally. Non-premium models (such as the free-tier models available without login) don't touch your premium credit.",
},
{
- question: "How many premium tokens do I get for free?",
+ question: "How much premium credit do I get for free?",
answer:
- "Every registered SurfSense account starts with 3 million premium tokens at no cost. Anonymous users (no login) get 500,000 free tokens across all models. Once your free tokens are used up, you can purchase more at any time.",
+ "Every registered SurfSense account starts with $5 of premium credit at no cost. Anonymous users (no login) get 500,000 free tokens across all free models. Once your free credit runs out, you can top up at any time.",
},
{
- question: "How does purchasing premium tokens work?",
+ question: "How does buying premium credit work?",
answer:
- "Just like pages, there's no subscription. You buy 1-million-token packs at $1 each whenever you need more. Purchased tokens are added to your account immediately. You can buy up to 100 packs at a time.",
+ "Just like pages, there's no subscription. Top-ups buy $1 of credit for $1 — every cent you pay is spent at provider cost, no markup. Purchased credit is added to your account immediately. You can buy up to $100 at a time.",
},
{
- question: "What happens if I run out of premium tokens?",
+ question: "What happens if I run out of premium credit?",
answer:
- "When your premium token balance runs low (below 20%), you'll see a warning. Once you run out, premium model requests are paused until you purchase more tokens. You can always switch to non-premium models which don't consume premium tokens.",
+ "When your premium credit balance runs low (below 20%), you'll see a warning. Once you run out, premium model requests are paused until you top up. You can always switch to non-premium models, which don't touch your premium credit.",
},
],
},
@@ -157,9 +157,9 @@ const faqData: FAQSection[] = [
title: "Self-Hosting",
items: [
{
- question: "Can I self-host SurfSense with unlimited pages and tokens?",
+ question: "Can I self-host SurfSense with unlimited pages and credit?",
answer:
- "Yes! When self-hosting, you have full control over your page and token limits. The default self-hosted setup gives you effectively unlimited pages and tokens, so you can index as much data and use as many AI queries as your infrastructure supports.",
+ "Yes! When self-hosting, you have full control over your page and premium-credit limits. The default self-hosted setup gives you effectively unlimited pages and premium credit, so you can index as much data and use as many AI queries as your infrastructure supports.",
},
],
},
@@ -250,8 +250,8 @@ function PricingFAQ() {
Frequently Asked Questions
- Everything you need to know about SurfSense pages, premium tokens, and billing. Can't
- find what you need? Reach out at{" "}
+ Everything you need to know about SurfSense pages, premium credit, and billing.
+ Can't find what you need? Reach out at{" "}
rohan@surfsense.com
@@ -335,7 +335,7 @@ function PricingBasic() {
>
diff --git a/surfsense_web/components/settings/buy-tokens-content.tsx b/surfsense_web/components/settings/buy-tokens-content.tsx
index e7fac4255..79a1b4943 100644
--- a/surfsense_web/components/settings/buy-tokens-content.tsx
+++ b/surfsense_web/components/settings/buy-tokens-content.tsx
@@ -14,10 +14,23 @@ import { AppError } from "@/lib/error";
import { cn } from "@/lib/utils";
import { queries } from "@/zero/queries";
-const TOKEN_PACK_SIZE = 1_000_000;
+// One pack = $1.00 of credit, stored as 1_000_000 micro-USD on the
+// backend. Premium turns are debited at the actual provider cost
+// reported by LiteLLM, so $1 of credit always buys $1 of provider
+// usage at cost.
+const CREDIT_PER_PACK_MICROS = 1_000_000;
const PRICE_PER_PACK_USD = 1;
const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50] as const;
+const formatUsd = (micros: number, options?: { compact?: boolean }) => {
+ const dollars = micros / 1_000_000;
+ if (options?.compact && dollars >= 1) return `$${dollars.toFixed(2)}`;
+ if (dollars >= 100) return `$${dollars.toFixed(0)}`;
+ if (dollars >= 1) return `$${dollars.toFixed(2)}`;
+ if (dollars > 0) return `$${dollars.toFixed(3)}`;
+ return "$0";
+};
+
export function BuyTokensContent() {
const params = useParams();
const searchSpaceId = Number(params?.search_space_id);
@@ -29,7 +42,7 @@ export function BuyTokensContent() {
queryFn: () => stripeApiService.getTokenStatus(),
});
- // Live per-user usage via Zero.
+ // Live per-user balance via Zero.
const [me] = useZeroQuery(queries.user.me({}));
const purchaseMutation = useMutation({
@@ -46,44 +59,46 @@ export function BuyTokensContent() {
},
});
- const totalTokens = quantity * TOKEN_PACK_SIZE;
+ const totalCreditMicros = quantity * CREDIT_PER_PACK_MICROS;
const totalPrice = quantity * PRICE_PER_PACK_USD;
if (tokenStatus && !tokenStatus.token_buying_enabled) {
return (
-
Buy Premium Tokens
+
Buy Premium Credit
- Token purchases are temporarily unavailable.
+ Credit purchases are temporarily unavailable.
);
}
- const used = me?.premiumTokensUsed ?? 0;
- const limit = me?.premiumTokensLimit ?? 0;
- // Mirrors the backend formula in stripe_routes.py:608 (max(0, limit - used)).
+ const used = me?.premiumCreditMicrosUsed ?? 0;
+ const limit = me?.premiumCreditMicrosLimit ?? 0;
+ // Mirrors the backend formula in stripe_routes.py (max(0, limit - used)).
const remaining = Math.max(0, limit - used);
const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0;
return (
-
Buy Premium Tokens
-
$1 per 1M tokens, pay as you go
+
Buy Premium Credit
+
+ $1 buys $1 of credit, billed at provider cost
+
{me && (
- {used.toLocaleString()} / {limit.toLocaleString()} premium tokens
+ {formatUsd(used)} / {formatUsd(limit)} of credit
{usagePercentage.toFixed(0)}%
- {remaining.toLocaleString()} tokens remaining
+ {formatUsd(remaining)} of credit remaining
)}
@@ -99,7 +114,7 @@ export function BuyTokensContent() {
- {(totalTokens / 1_000_000).toFixed(0)}M tokens
+ ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit
- {m}M
+ ${m}
))}
- {(totalTokens / 1_000_000).toFixed(0)}M premium tokens
+ ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit
${totalPrice}
@@ -149,7 +164,7 @@ export function BuyTokensContent() {
>
) : (
<>
- Buy {(totalTokens / 1_000_000).toFixed(0)}M Tokens for ${totalPrice}
+ Buy ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit for ${totalPrice}
>
)}
diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx
index f5f128f80..ced97464e 100644
--- a/surfsense_web/components/settings/image-model-manager.tsx
+++ b/surfsense_web/components/settings/image-model-manager.tsx
@@ -190,7 +190,25 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
? "model"
: "models"}
{" "}
- available from your administrator.
+ available from your administrator.{" "}
+ {(() => {
+ const nonAuto = globalConfigs.filter(
+ (g) => !("is_auto_mode" in g && g.is_auto_mode)
+ );
+ const premium = nonAuto.filter(
+ (g) =>
+ "billing_tier" in g &&
+ (g as { billing_tier?: string }).billing_tier === "premium"
+ ).length;
+ const free = nonAuto.length - premium;
+ if (premium > 0 && free > 0) {
+ return `${premium} premium, ${free} free.`;
+ }
+ if (premium > 0) {
+ return `All ${premium} premium — usage debits your shared credit pool.`;
+ }
+ return `All ${free} free.`;
+ })()}
diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx
index e21dc9028..a2eb6a22e 100644
--- a/surfsense_web/components/settings/llm-role-manager.tsx
+++ b/surfsense_web/components/settings/llm-role-manager.tsx
@@ -371,6 +371,17 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
{roleGlobalConfigs.map((config) => {
const isAuto = "is_auto_mode" in config && config.is_auto_mode;
+ // Read billing_tier from the global config; default to "free"
+ // for legacy YAMLs / Auto stub. Premium gets a purple badge,
+ // free gets an emerald one — same palette as the chat
+ // model selector so the meaning is consistent across
+ // surfaces (issues E, H).
+ const billingTier =
+ ("billing_tier" in config &&
+ typeof config.billing_tier === "string" &&
+ config.billing_tier) ||
+ "free";
+ const isPremium = billingTier === "premium";
return (
{config.name}
- {isAuto && (
+ {isAuto ? (
Recommended
+ ) : isPremium ? (
+
+ Premium
+
+ ) : (
+
+ Free
+
)}
diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx
index 8abfa4774..886d71008 100644
--- a/surfsense_web/components/settings/vision-model-manager.tsx
+++ b/surfsense_web/components/settings/vision-model-manager.tsx
@@ -191,7 +191,25 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
? "model"
: "models"}
{" "}
- available from your administrator.
+ available from your administrator.{" "}
+ {(() => {
+ const nonAuto = globalConfigs.filter(
+ (g) => !("is_auto_mode" in g && g.is_auto_mode)
+ );
+ const premium = nonAuto.filter(
+ (g) =>
+ "billing_tier" in g &&
+ (g as { billing_tier?: string }).billing_tier === "premium"
+ ).length;
+ const free = nonAuto.length - premium;
+ if (premium > 0 && free > 0) {
+ return `${premium} premium, ${free} free.`;
+ }
+ if (premium > 0) {
+ return `All ${premium} premium — usage debits your shared credit pool.`;
+ }
+ return `All ${free} free.`;
+ })()}
diff --git a/surfsense_web/contexts/login-gate.tsx b/surfsense_web/contexts/login-gate.tsx
index fad64fa9f..790e5c00e 100644
--- a/surfsense_web/contexts/login-gate.tsx
+++ b/surfsense_web/contexts/login-gate.tsx
@@ -44,8 +44,8 @@ export function LoginGateProvider({ children }: { children: ReactNode }) {
Create a free account to {feature}
- Get 3 million tokens, save chat history, upload documents, use all AI tools, and
- connect 30+ integrations.
+ Get $5 of premium credit, save chat history, upload documents, use all AI tools,
+ and connect 30+ integrations.
diff --git a/surfsense_web/contracts/types/new-llm-config.types.ts b/surfsense_web/contracts/types/new-llm-config.types.ts
index ecffc573e..2d6b70eda 100644
--- a/surfsense_web/contracts/types/new-llm-config.types.ts
+++ b/surfsense_web/contracts/types/new-llm-config.types.ts
@@ -258,6 +258,8 @@ export const globalImageGenConfig = z.object({
litellm_params: z.record(z.string(), z.any()).nullable().optional(),
is_global: z.literal(true),
is_auto_mode: z.boolean().optional().default(false),
+ billing_tier: z.string().default("free"),
+ quota_reserve_micros: z.number().nullable().optional(),
});
export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig);
@@ -338,6 +340,10 @@ export const globalVisionLLMConfig = z.object({
litellm_params: z.record(z.string(), z.any()).nullable().optional(),
is_global: z.literal(true),
is_auto_mode: z.boolean().optional().default(false),
+ billing_tier: z.string().default("free"),
+ quota_reserve_tokens: z.number().nullable().optional(),
+ input_cost_per_token: z.number().nullable().optional(),
+ output_cost_per_token: z.number().nullable().optional(),
});
export const getGlobalVisionLLMConfigsResponse = z.array(globalVisionLLMConfig);
diff --git a/surfsense_web/contracts/types/stripe.types.ts b/surfsense_web/contracts/types/stripe.types.ts
index c8b017044..251f7a176 100644
--- a/surfsense_web/contracts/types/stripe.types.ts
+++ b/surfsense_web/contracts/types/stripe.types.ts
@@ -32,7 +32,7 @@ export const getPagePurchasesResponse = z.object({
purchases: z.array(pagePurchase),
});
-// Premium token purchases
+// Premium credit purchases
export const createTokenCheckoutSessionRequest = z.object({
quantity: z.number().int().min(1).max(100),
search_space_id: z.number().int().min(1),
@@ -42,11 +42,16 @@ export const createTokenCheckoutSessionResponse = z.object({
checkout_url: z.string(),
});
+// Premium credit balance + purchase records.
+//
+// The unit is integer micro-USD (1_000_000 == $1.00). The schema names
+// kept the ``Token`` prefix for API back-compat with pinned clients;
+// the field names below are authoritative.
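+// Example: premium_credit_micros_remaining = 2_500_000 -> $2.50 of credit remaining.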
export const tokenStripeStatusResponse = z.object({
token_buying_enabled: z.boolean(),
- premium_tokens_used: z.number().default(0),
- premium_tokens_limit: z.number().default(0),
- premium_tokens_remaining: z.number().default(0),
+ premium_credit_micros_used: z.number().default(0),
+ premium_credit_micros_limit: z.number().default(0),
+ premium_credit_micros_remaining: z.number().default(0),
});
export const tokenPurchaseStatusEnum = pagePurchaseStatusEnum;
@@ -56,7 +61,7 @@ export const tokenPurchase = z.object({
stripe_checkout_session_id: z.string(),
stripe_payment_intent_id: z.string().nullable(),
quantity: z.number(),
- tokens_granted: z.number(),
+ credit_micros_granted: z.number(),
amount_total: z.number().nullable(),
currency: z.string().nullable(),
status: tokenPurchaseStatusEnum,
diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts
index 95d9848f2..1c67d59a1 100644
--- a/surfsense_web/lib/chat/chat-error-classifier.ts
+++ b/surfsense_web/lib/chat/chat-error-classifier.ts
@@ -41,7 +41,7 @@ export interface RawChatErrorInput {
}
export const PREMIUM_QUOTA_ASSISTANT_MESSAGE =
- "I can’t continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue.";
+ "I can’t continue with the current premium model because your premium credit is exhausted. Switch to a free model or top up your credit to continue.";
function getErrorMessage(error: unknown): string {
if (error instanceof Error) return error.message;
diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts
index 80e7bffbe..6df56f0ce 100644
--- a/surfsense_web/lib/chat/streaming-state.ts
+++ b/surfsense_web/lib/chat/streaming-state.ts
@@ -541,16 +541,23 @@ export type SSEEvent =
data: {
usage: Record<
string,
- { prompt_tokens: number; completion_tokens: number; total_tokens: number }
+ {
+ prompt_tokens: number;
+ completion_tokens: number;
+ total_tokens: number;
+ cost_micros?: number;
+ }
>;
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
+ cost_micros?: number;
call_details: Array<{
model: string;
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
+ cost_micros?: number;
}>;
};
}
diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts
index fc970c26e..7fec60a23 100644
--- a/surfsense_web/lib/chat/thread-persistence.ts
+++ b/surfsense_web/lib/chat/thread-persistence.ts
@@ -30,9 +30,20 @@ export interface TokenUsageSummary {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
+ /**
+ * Total provider USD cost for this assistant turn, in micro-USD
+ * (1_000_000 = $1.00). Optional because rows persisted before the
+ * cost-credits migration won't have it.
+ */
+ cost_micros?: number;
model_breakdown?: Record<
string,
- { prompt_tokens: number; completion_tokens: number; total_tokens: number }
+ {
+ prompt_tokens: number;
+ completion_tokens: number;
+ total_tokens: number;
+ cost_micros?: number;
+ }
> | null;
}
diff --git a/surfsense_web/zero/schema/user.ts b/surfsense_web/zero/schema/user.ts
index 0e6234db5..f483fa9b4 100644
--- a/surfsense_web/zero/schema/user.ts
+++ b/surfsense_web/zero/schema/user.ts
@@ -1,11 +1,20 @@
import { number, string, table } from "@rocicorp/zero";
+/**
+ * Live-meter slice of the ``user`` table replicated through Zero.
+ *
+ * ``premiumCreditMicrosLimit`` / ``premiumCreditMicrosUsed`` are stored
+ * as integer micro-USD (1_000_000 == $1.00). UI consumers divide by 1M
+ * when displaying. Sensitive fields (email, hashed_password, oauth, etc.)
+ * are intentionally omitted via the Postgres column-list publication so
+ * they never enter the logical-replication stream.
+ */
export const userTable = table("user")
.columns({
id: string(),
pagesLimit: number().from("pages_limit"),
pagesUsed: number().from("pages_used"),
- premiumTokensLimit: number().from("premium_tokens_limit"),
- premiumTokensUsed: number().from("premium_tokens_used"),
+ premiumCreditMicrosLimit: number().from("premium_credit_micros_limit"),
+ premiumCreditMicrosUsed: number().from("premium_credit_micros_used"),
})
.primaryKey("id");
From 47b2994ec76c88b45c1cae55116372be87368e9f Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Sat, 2 May 2026 19:18:53 -0700
Subject: [PATCH 03/12] feat: fixed vision/image provider specific errors and
fixed podcast/video streaming
---
.../app/agents/new_chat/llm_config.py | 132 +++--
.../agents/new_chat/tools/generate_image.py | 43 +-
surfsense_backend/app/config/__init__.py | 26 +
.../app/routes/image_generation_routes.py | 49 +-
.../app/routes/new_llm_config_routes.py | 73 ++-
.../app/routes/vision_llm_routes.py | 8 +-
.../app/schemas/image_generation.py | 9 +
.../app/schemas/new_llm_config.py | 36 ++
surfsense_backend/app/schemas/vision_llm.py | 9 +
.../app/services/auto_model_pin_service.py | 87 ++-
.../app/services/billable_calls.py | 258 ++++++--
.../app/services/image_gen_router_service.py | 21 +-
.../app/services/llm_router_service.py | 2 -
surfsense_backend/app/services/llm_service.py | 35 +-
.../openrouter_integration_service.py | 38 +-
.../app/services/provider_api_base.py | 1 -
.../app/services/provider_capabilities.py | 280 +++++++++
.../app/tasks/celery_tasks/__init__.py | 100 +++-
.../app/tasks/celery_tasks/connector_tasks.py | 197 ++-----
.../celery_tasks/document_reindex_tasks.py | 12 +-
.../app/tasks/celery_tasks/document_tasks.py | 188 ++----
.../app/tasks/celery_tasks/obsidian_tasks.py | 20 +-
.../app/tasks/celery_tasks/podcast_tasks.py | 50 +-
.../celery_tasks/schedule_checker_task.py | 12 +-
.../stale_notification_cleanup_task.py | 14 +-
.../stripe_reconciliation_task.py | 19 +-
.../celery_tasks/video_presentation_tasks.py | 58 +-
.../app/tasks/chat/stream_new_chat.py | 135 ++++-
.../scripts/verify_chat_image_capability.py | 558 ++++++++++++++++++
.../routes/test_byok_supports_image_input.py | 110 ++++
.../routes/test_global_configs_is_premium.py | 184 ++++++
...t_global_new_llm_configs_supports_image.py | 106 ++++
.../services/test_auto_pin_image_aware.py | 286 +++++++++
.../tests/unit/services/test_billable_call.py | 131 +++-
.../test_image_gen_api_base_defense.py | 177 ++++++
.../test_openrouter_integration_service.py | 8 +
.../unit/services/test_provider_api_base.py | 107 ++++
.../services/test_provider_capabilities.py | 244 ++++++++
.../services/test_supports_image_input.py | 281 +++++++++
.../test_vision_llm_api_base_defense.py | 89 +++
.../unit/tasks/test_celery_async_runner.py | 318 ++++++++++
.../tests/unit/tasks/test_podcast_billing.py | 67 ++-
.../test_stream_new_chat_image_safety_net.py | 119 ++++
.../tasks/test_video_presentation_billing.py | 70 ++-
.../assistant-ui/assistant-message.tsx | 4 +-
.../components/new-chat/model-selector.tsx | 46 +-
.../components/pricing/pricing-section.tsx | 4 +-
.../settings/image-model-manager.tsx | 73 ++-
.../settings/more-pages-content.tsx | 4 +-
.../settings/vision-model-manager.tsx | 73 ++-
.../components/tool-ui/generate-podcast.tsx | 21 +-
surfsense_web/contexts/login-gate.tsx | 4 +-
.../contracts/types/new-llm-config.types.ts | 30 +
surfsense_web/next.config.ts | 6 +
54 files changed, 4469 insertions(+), 563 deletions(-)
create mode 100644 surfsense_backend/app/services/provider_capabilities.py
create mode 100644 surfsense_backend/scripts/verify_chat_image_capability.py
create mode 100644 surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py
create mode 100644 surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py
create mode 100644 surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py
create mode 100644 surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py
create mode 100644 surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py
create mode 100644 surfsense_backend/tests/unit/services/test_provider_api_base.py
create mode 100644 surfsense_backend/tests/unit/services/test_provider_capabilities.py
create mode 100644 surfsense_backend/tests/unit/services/test_supports_image_input.py
create mode 100644 surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py
create mode 100644 surfsense_backend/tests/unit/tasks/test_celery_async_runner.py
create mode 100644 surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py
diff --git a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/new_chat/llm_config.py
index 99bb719f6..bc37bf1c4 100644
--- a/surfsense_backend/app/agents/new_chat/llm_config.py
+++ b/surfsense_backend/app/agents/new_chat/llm_config.py
@@ -90,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
yield chunk
-# Provider mapping for LiteLLM model string construction
-PROVIDER_MAP = {
- "OPENAI": "openai",
- "ANTHROPIC": "anthropic",
- "GROQ": "groq",
- "COHERE": "cohere",
- "GOOGLE": "gemini",
- "OLLAMA": "ollama_chat",
- "MISTRAL": "mistral",
- "AZURE_OPENAI": "azure",
- "OPENROUTER": "openrouter",
- "XAI": "xai",
- "BEDROCK": "bedrock",
- "VERTEX_AI": "vertex_ai",
- "TOGETHER_AI": "together_ai",
- "FIREWORKS_AI": "fireworks_ai",
- "DEEPSEEK": "openai",
- "ALIBABA_QWEN": "openai",
- "MOONSHOT": "openai",
- "ZHIPU": "openai",
- "GITHUB_MODELS": "github",
- "REPLICATE": "replicate",
- "PERPLEXITY": "perplexity",
- "ANYSCALE": "anyscale",
- "DEEPINFRA": "deepinfra",
- "CEREBRAS": "cerebras",
- "SAMBANOVA": "sambanova",
- "AI21": "ai21",
- "CLOUDFLARE": "cloudflare",
- "DATABRICKS": "databricks",
- "COMETAPI": "cometapi",
- "HUGGINGFACE": "huggingface",
- "MINIMAX": "openai",
- "CUSTOM": "custom",
-}
+# Provider mapping for LiteLLM model string construction.
+#
+# Single source of truth lives in
+# :mod:`app.services.provider_capabilities` so the YAML loader (which
+# runs during ``app.config`` class-body init) can resolve provider
+# prefixes without dragging the agent / tools tree into module load
+# order. Re-exported here under the historical ``PROVIDER_MAP`` name
+# so existing callers (``llm_router_service``, ``image_gen_router_service``,
+# tests) keep working unchanged.
+from app.services.provider_capabilities import ( # noqa: E402
+ _PROVIDER_PREFIX_MAP as PROVIDER_MAP,
+)
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
@@ -178,6 +155,17 @@ class AgentConfig:
anonymous_enabled: bool = False
quota_reserve_tokens: int | None = None
+ # Capability flag: best-effort True for the chat selector / catalog.
+ # Resolved via :func:`provider_capabilities.derive_supports_image_input`
+ # which prefers OpenRouter's ``architecture.input_modalities`` and
+ # otherwise consults LiteLLM's authoritative model map. Default True
+ # is the conservative-allow stance — the streaming-task safety net
+ # (``is_known_text_only_chat_model``) is the *only* place a False
+ # actually blocks a request. Setting this to False here without an
+ # authoritative source would silently hide vision-capable models
+ # (the regression we're fixing).
+ supports_image_input: bool = True
+
@classmethod
def from_auto_mode(cls) -> "AgentConfig":
"""
@@ -203,6 +191,12 @@ class AgentConfig:
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
+ # Auto routes across the configured pool, which usually
+ # contains at least one vision-capable deployment; the router
+ # will surface a 404 from a non-vision deployment as a normal
+ # ``allowed_fails`` event and fail over rather than blocking
+ # the request outright.
+ supports_image_input=True,
)
@classmethod
@@ -216,10 +210,24 @@ class AgentConfig:
Returns:
AgentConfig instance
"""
- return cls(
- provider=config.provider.value
+ # Lazy import to avoid pulling provider_capabilities (and its
+ # transitive litellm import) into module-init order.
+ from app.services.provider_capabilities import derive_supports_image_input
+
+ provider_value = (
+ config.provider.value
if hasattr(config.provider, "value")
- else str(config.provider),
+ else str(config.provider)
+ )
+ litellm_params = config.litellm_params or {}
+ base_model = (
+ litellm_params.get("base_model")
+ if isinstance(litellm_params, dict)
+ else None
+ )
+
+ return cls(
+ provider=provider_value,
model_name=config.model_name,
api_key=config.api_key,
api_base=config.api_base,
@@ -235,6 +243,16 @@ class AgentConfig:
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
+ # BYOK rows have no operator-curated capability flag, so we
+ # ask LiteLLM (default-allow on unknown). The streaming
+ # safety net still blocks if the model is *explicitly*
+ # marked text-only.
+ supports_image_input=derive_supports_image_input(
+ provider=provider_value,
+ model_name=config.model_name,
+ base_model=base_model,
+ custom_provider=config.custom_provider,
+ ),
)
@classmethod
@@ -253,15 +271,46 @@ class AgentConfig:
Returns:
AgentConfig instance
"""
+ # Lazy import to avoid pulling provider_capabilities (and its
+ # transitive litellm import) into module-init order.
+ from app.services.provider_capabilities import derive_supports_image_input
+
# Get system instructions from YAML, default to empty string
system_instructions = yaml_config.get("system_instructions", "")
+ provider = yaml_config.get("provider", "").upper()
+ model_name = yaml_config.get("model_name", "")
+ custom_provider = yaml_config.get("custom_provider")
+ litellm_params = yaml_config.get("litellm_params") or {}
+ base_model = (
+ litellm_params.get("base_model")
+ if isinstance(litellm_params, dict)
+ else None
+ )
+
+ # Explicit YAML override wins; otherwise derive from LiteLLM /
+ # OpenRouter modalities. The YAML loader already populates this
+ # field, but this method is also called from
+ # ``load_global_llm_config_by_id``'s file fallback (hot reload),
+ # so we re-derive here for safety. The bool() coercion preserves
+ # the loader's behaviour for explicit ``true`` / ``false``
+ # strings that PyYAML may surface.
+ if "supports_image_input" in yaml_config:
+ supports_image_input = bool(yaml_config.get("supports_image_input"))
+ else:
+ supports_image_input = derive_supports_image_input(
+ provider=provider,
+ model_name=model_name,
+ base_model=base_model,
+ custom_provider=custom_provider,
+ )
+
return cls(
- provider=yaml_config.get("provider", "").upper(),
- model_name=yaml_config.get("model_name", ""),
+ provider=provider,
+ model_name=model_name,
api_key=yaml_config.get("api_key", ""),
api_base=yaml_config.get("api_base"),
- custom_provider=yaml_config.get("custom_provider"),
+ custom_provider=custom_provider,
litellm_params=yaml_config.get("litellm_params"),
# Prompt configuration from YAML (with defaults for backwards compatibility)
system_instructions=system_instructions if system_instructions else None,
@@ -276,6 +325,7 @@ class AgentConfig:
is_premium=yaml_config.get("billing_tier", "free") == "premium",
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
+ supports_image_input=supports_image_input,
)
diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py
index 3803fa39c..9e287ac51 100644
--- a/surfsense_backend/app/agents/new_chat/tools/generate_image.py
+++ b/surfsense_backend/app/agents/new_chat/tools/generate_image.py
@@ -31,6 +31,7 @@ from app.services.image_gen_router_service import (
ImageGenRouterService,
is_image_gen_auto_mode,
)
+from app.services.provider_api_base import resolve_api_base
from app.utils.signed_image_urls import generate_image_token
logger = logging.getLogger(__name__)
@@ -49,12 +50,16 @@ _PROVIDER_MAP = {
}
+def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
+ if custom_provider:
+ return custom_provider
+ return _PROVIDER_MAP.get(provider.upper(), provider.lower())
+
+
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
- if custom_provider:
- return f"{custom_provider}/{model_name}"
- prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
+ prefix = _resolve_provider_prefix(provider, custom_provider)
return f"{prefix}/{model_name}"
@@ -146,14 +151,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found"
}
- model_string = _build_model_string(
- cfg.get("provider", ""),
- cfg["model_name"],
- cfg.get("custom_provider"),
+ provider_prefix = _resolve_provider_prefix(
+ cfg.get("provider", ""), cfg.get("custom_provider")
)
+ model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
- if cfg.get("api_base"):
- gen_kwargs["api_base"] = cfg["api_base"]
+ api_base = resolve_api_base(
+ provider=cfg.get("provider"),
+ provider_prefix=provider_prefix,
+ config_api_base=cfg.get("api_base"),
+ )
+ if api_base:
+ gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
@@ -175,14 +184,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found"
}
- model_string = _build_model_string(
- db_cfg.provider.value,
- db_cfg.model_name,
- db_cfg.custom_provider,
+ provider_prefix = _resolve_provider_prefix(
+ db_cfg.provider.value, db_cfg.custom_provider
)
+ model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
- if db_cfg.api_base:
- gen_kwargs["api_base"] = db_cfg.api_base
+ api_base = resolve_api_base(
+ provider=db_cfg.provider.value,
+ provider_prefix=provider_prefix,
+ config_api_base=db_cfg.api_base,
+ )
+ if api_base:
+ gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py
index 2aeeafb34..97b4cf509 100644
--- a/surfsense_backend/app/config/__init__.py
+++ b/surfsense_backend/app/config/__init__.py
@@ -47,11 +47,37 @@ def load_global_llm_configs():
data = yaml.safe_load(f)
configs = data.get("global_llm_configs", [])
+ # Lazy import keeps the `app.config` -> `app.services` edge one-way
+ # and matches the `provider_api_base` pattern used elsewhere.
+ from app.services.provider_capabilities import derive_supports_image_input
+
seen_slugs: dict[str, int] = {}
for cfg in configs:
cfg.setdefault("billing_tier", "free")
cfg.setdefault("anonymous_enabled", False)
cfg.setdefault("seo_enabled", False)
+ # Capability flag: explicit YAML override always wins. When the
+ # operator has not annotated the model, defer to LiteLLM's
+ # authoritative model map (`supports_vision`) which already
+ # knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are
+ # vision-capable. Unknown / unmapped models default-allow so
+ # we don't lock the user out of a freshly added third-party
+ # entry; the streaming-task safety net (driven by
+ # `is_known_text_only_chat_model`) is the only place a False
+ # actually blocks a request.
+ if "supports_image_input" not in cfg:
+ litellm_params = cfg.get("litellm_params") or {}
+ base_model = (
+ litellm_params.get("base_model")
+ if isinstance(litellm_params, dict)
+ else None
+ )
+ cfg["supports_image_input"] = derive_supports_image_input(
+ provider=cfg.get("provider"),
+ model_name=cfg.get("model_name"),
+ base_model=base_model,
+ custom_provider=cfg.get("custom_provider"),
+ )
if cfg.get("seo_enabled") and cfg.get("seo_slug"):
slug = cfg["seo_slug"]
diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py
index 34ed80207..018234ad5 100644
--- a/surfsense_backend/app/routes/image_generation_routes.py
+++ b/surfsense_backend/app/routes/image_generation_routes.py
@@ -46,6 +46,7 @@ from app.services.image_gen_router_service import (
ImageGenRouterService,
is_image_gen_auto_mode,
)
+from app.services.provider_api_base import resolve_api_base
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
@@ -87,14 +88,18 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
return None
+def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
+ """Resolve the LiteLLM provider prefix used in model strings."""
+ if custom_provider:
+ return custom_provider
+ return _PROVIDER_MAP.get(provider.upper(), provider.lower())
+
+
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
"""Build a litellm model string from provider + model_name."""
- if custom_provider:
- return f"{custom_provider}/{model_name}"
- prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
- return f"{prefix}/{model_name}"
+ return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
async def _resolve_billing_for_image_gen(
@@ -187,12 +192,18 @@ async def _execute_image_generation(
if not cfg:
raise ValueError(f"Global image generation config {config_id} not found")
- model_string = _build_model_string(
- cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
+ provider_prefix = _resolve_provider_prefix(
+ cfg.get("provider", ""), cfg.get("custom_provider")
)
+ model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
- if cfg.get("api_base"):
- gen_kwargs["api_base"] = cfg["api_base"]
+ api_base = resolve_api_base(
+ provider=cfg.get("provider"),
+ provider_prefix=provider_prefix,
+ config_api_base=cfg.get("api_base"),
+ )
+ if api_base:
+ gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
@@ -214,12 +225,18 @@ async def _execute_image_generation(
if not db_cfg:
raise ValueError(f"Image generation config {config_id} not found")
- model_string = _build_model_string(
- db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
+ provider_prefix = _resolve_provider_prefix(
+ db_cfg.provider.value, db_cfg.custom_provider
)
+ model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
- if db_cfg.api_base:
- gen_kwargs["api_base"] = db_cfg.api_base
+ api_base = resolve_api_base(
+ provider=db_cfg.provider.value,
+ provider_prefix=provider_prefix,
+ config_api_base=db_cfg.api_base,
+ )
+ if api_base:
+ gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
@@ -277,10 +294,12 @@ async def get_global_image_gen_configs(
# Auto mode currently treated as free until per-deployment
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
"billing_tier": "free",
+ "is_premium": False,
}
)
for cfg in global_configs:
+ billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
@@ -293,7 +312,11 @@ async def get_global_image_gen_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
- "billing_tier": cfg.get("billing_tier", "free"),
+ "billing_tier": billing_tier,
+ # Mirror chat (``new_llm_config_routes``) so the new-chat
+ # selector's premium badge logic keys off the same
+ # field across chat / image / vision tabs.
+ "is_premium": billing_tier == "premium",
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
}
)
diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py
index 20779a309..e090a1a7c 100644
--- a/surfsense_backend/app/routes/new_llm_config_routes.py
+++ b/surfsense_backend/app/routes/new_llm_config_routes.py
@@ -29,6 +29,7 @@ from app.schemas import (
NewLLMConfigUpdate,
)
from app.services.llm_service import validate_llm_config
+from app.services.provider_capabilities import derive_supports_image_input
from app.users import current_active_user
from app.utils.rbac import check_permission
@@ -36,6 +37,39 @@ router = APIRouter()
logger = logging.getLogger(__name__)
+def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
+ """Augment a BYOK chat config row with the derived ``supports_image_input``.
+
+ There is no DB column for ``supports_image_input`` — the value is
+ resolved at the API boundary from LiteLLM's authoritative model map
+ (default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps
+ the response shape consistent across list / detail / create / update
+ endpoints without having to remember to set the field at every call
+ site.
+ """
+ provider_value = (
+ config.provider.value
+ if hasattr(config.provider, "value")
+ else str(config.provider)
+ )
+ litellm_params = config.litellm_params or {}
+ base_model = (
+ litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
+ )
+ supports_image_input = derive_supports_image_input(
+ provider=provider_value,
+ model_name=config.model_name,
+ base_model=base_model,
+ custom_provider=config.custom_provider,
+ )
+ # ``model_validate`` runs the Pydantic conversion using the ORM
+ # attribute access path enabled by ``ConfigDict(from_attributes=True)``,
+ # then we layer the derived field on. ``model_copy(update=...)`` keeps
+ # the surface immutable from the caller's perspective.
+ base_read = NewLLMConfigRead.model_validate(config)
+ return base_read.model_copy(update={"supports_image_input": supports_image_input})
+
+
# =============================================================================
# Global Configs Routes
# =============================================================================
@@ -84,11 +118,41 @@ async def get_global_new_llm_configs(
"seo_title": None,
"seo_description": None,
"quota_reserve_tokens": None,
+ # Auto routes across the configured pool, which usually
+ # includes at least one vision-capable deployment, so
+ # treat Auto as image-capable. The router itself will
+ # still pick a vision-capable deployment for messages
+ # carrying image_url blocks (LiteLLM Router falls back
+ # on ``404`` per its ``allowed_fails`` policy).
+ "supports_image_input": True,
}
)
# Add individual global configs
for cfg in global_configs:
+ # Capability resolution: an explicit value (a YAML override or the
+ # ``_supports_image_input(model)`` payload baked in by the
+ # OpenRouter integration service) wins. Fall back to the
+ # LiteLLM-driven helper which default-allows on unknown so
+ # we don't hide vision-capable models that happen to lack a
+ # YAML annotation. The streaming task safety net is the
+ # only place a False ever blocks.
+ if "supports_image_input" in cfg:
+ supports_image_input = bool(cfg.get("supports_image_input"))
+ else:
+ cfg_litellm_params = cfg.get("litellm_params") or {}
+ cfg_base_model = (
+ cfg_litellm_params.get("base_model")
+ if isinstance(cfg_litellm_params, dict)
+ else None
+ )
+ supports_image_input = derive_supports_image_input(
+ provider=cfg.get("provider"),
+ model_name=cfg.get("model_name"),
+ base_model=cfg_base_model,
+ custom_provider=cfg.get("custom_provider"),
+ )
+
safe_config = {
"id": cfg.get("id"),
"name": cfg.get("name"),
@@ -113,6 +177,7 @@ async def get_global_new_llm_configs(
"seo_title": cfg.get("seo_title"),
"seo_description": cfg.get("seo_description"),
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
+ "supports_image_input": supports_image_input,
}
safe_configs.append(safe_config)
@@ -171,7 +236,7 @@ async def create_new_llm_config(
await session.commit()
await session.refresh(db_config)
- return db_config
+ return _serialize_byok_config(db_config)
except HTTPException:
raise
@@ -213,7 +278,7 @@ async def list_new_llm_configs(
.limit(limit)
)
- return result.scalars().all()
+ return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
except HTTPException:
raise
@@ -268,7 +333,7 @@ async def get_new_llm_config(
"You don't have permission to view LLM configurations in this search space",
)
- return config
+ return _serialize_byok_config(config)
except HTTPException:
raise
@@ -360,7 +425,7 @@ async def update_new_llm_config(
await session.commit()
await session.refresh(config)
- return config
+ return _serialize_byok_config(config)
except HTTPException:
raise
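
For a BYOK row whose ``model_name`` is an opaque deployment id, the derived flag leans on ``litellm_params.base_model``, as in this sketch (the deployment and sku names are made up for illustration):

    from app.services.provider_capabilities import derive_supports_image_input

    # Hypothetical Azure BYOK row: model_name is the deployment id,
    # base_model is the canonical sku LiteLLM's model map knows about.
    supports_image_input = derive_supports_image_input(
        provider="AZURE_OPENAI",
        model_name="my-gpt4o-deployment",   # not in LiteLLM's map
        base_model="gpt-4o",                # resolves via the map
        custom_provider=None,
    )
    # Unknown / unmapped rows default-allow, so a missing base_model still
    # yields True rather than hiding the model from the selector.
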
diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py
index 4f7e9f725..e4f08f604 100644
--- a/surfsense_backend/app/routes/vision_llm_routes.py
+++ b/surfsense_backend/app/routes/vision_llm_routes.py
@@ -85,10 +85,12 @@ async def get_global_vision_llm_configs(
# Auto mode treated as free until per-deployment billing-tier
# surfacing lands; see ``get_vision_llm`` for parity.
"billing_tier": "free",
+ "is_premium": False,
}
)
for cfg in global_configs:
+ billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
@@ -101,7 +103,11 @@ async def get_global_vision_llm_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
- "billing_tier": cfg.get("billing_tier", "free"),
+ "billing_tier": billing_tier,
+ # Mirror chat (``new_llm_config_routes``) so the new-chat
+ # selector's premium badge logic keys off the same
+ # field across chat / image / vision tabs.
+ "is_premium": billing_tier == "premium",
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"input_cost_per_token": cfg.get("input_cost_per_token"),
"output_cost_per_token": cfg.get("output_cost_per_token"),
diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py
index facca7b86..4262b2b3f 100644
--- a/surfsense_backend/app/schemas/image_generation.py
+++ b/surfsense_backend/app/schemas/image_generation.py
@@ -241,6 +241,15 @@ class GlobalImageGenConfigRead(BaseModel):
default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
)
+ is_premium: bool = Field(
+ default=False,
+ description=(
+ "Convenience boolean derived server-side from "
+ "``billing_tier == 'premium'``. The new-chat model selector "
+ "keys its Free/Premium badge off this field for parity with "
+ "chat (`GlobalLLMConfigRead.is_premium`)."
+ ),
+ )
quota_reserve_micros: int | None = Field(
default=None,
description=(
diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py
index 9cc1fce58..e64478d38 100644
--- a/surfsense_backend/app/schemas/new_llm_config.py
+++ b/surfsense_backend/app/schemas/new_llm_config.py
@@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase):
created_at: datetime
search_space_id: int
user_id: uuid.UUID
+ # Capability flag derived at the API boundary (no DB column). Default
+ # True matches the conservative-allow stance — a BYOK row that the
+ # route forgot to augment is not pre-judged. The streaming-task
+ # safety net is the only place a False actually blocks a request.
+ supports_image_input: bool = Field(
+ default=True,
+ description=(
+ "Whether the BYOK chat config can accept image inputs. Derived "
+ "at the route boundary from LiteLLM's authoritative model map "
+ "(``litellm.supports_vision``) — there is no DB column. "
+ "Default True is the conservative-allow stance for unknown / "
+ "unmapped models."
+ ),
+ )
model_config = ConfigDict(from_attributes=True)
@@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel):
created_at: datetime
search_space_id: int
user_id: uuid.UUID
+ # Capability flag derived at the API boundary (see NewLLMConfigRead).
+ supports_image_input: bool = Field(
+ default=True,
+ description=(
+ "Whether the BYOK chat config can accept image inputs. Derived "
+ "at the route boundary from LiteLLM's authoritative model map. "
+ "Default True is the conservative-allow stance."
+ ),
+ )
model_config = ConfigDict(from_attributes=True)
@@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel):
seo_title: str | None = None
seo_description: str | None = None
quota_reserve_tokens: int | None = None
+ supports_image_input: bool = Field(
+ default=True,
+ description=(
+ "Whether the model accepts image inputs (multimodal vision). "
+ "Derived server-side: OpenRouter dynamic configs use "
+ "``architecture.input_modalities``; YAML / BYOK use LiteLLM's "
+ "authoritative model map (``litellm.supports_vision``). The "
+ "new-chat selector hints with a 'No image' badge when this is "
+ "False and there are pending image attachments. The streaming "
+ "task fails fast only when LiteLLM *explicitly* marks a model "
+ "as text-only — unknown / unmapped models default-allow."
+ ),
+ )
# =============================================================================
diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py
index e55333a9d..d0eeaf5c6 100644
--- a/surfsense_backend/app/schemas/vision_llm.py
+++ b/surfsense_backend/app/schemas/vision_llm.py
@@ -86,6 +86,15 @@ class GlobalVisionLLMConfigRead(BaseModel):
default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
)
+ is_premium: bool = Field(
+ default=False,
+ description=(
+ "Convenience boolean derived server-side from "
+ "``billing_tier == 'premium'``. The new-chat model selector "
+ "keys its Free/Premium badge off this field for parity with "
+ "chat (`GlobalLLMConfigRead.is_premium`)."
+ ),
+ )
quota_reserve_tokens: int | None = Field(
default=None,
description=(
diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py
index 3a2c681b7..4f045ba02 100644
--- a/surfsense_backend/app/services/auto_model_pin_service.py
+++ b/surfsense_backend/app/services/auto_model_pin_service.py
@@ -163,13 +163,47 @@ def clear_healthy(config_id: int | None = None) -> None:
_healthy_until.pop(int(config_id), None)
-def _global_candidates() -> list[dict]:
+def _cfg_supports_image_input(cfg: dict) -> bool:
+ """True if the global cfg can accept image inputs.
+
+ Prefers the explicit ``supports_image_input`` flag (set by the YAML
+ loader / OpenRouter integration). Falls back to a LiteLLM lookup so
+ a YAML entry whose flag was somehow stripped doesn't get wrongly
+ excluded. Default-allows on unknown — the streaming-task safety net
+ is the actual block, not this filter.
+ """
+ if "supports_image_input" in cfg:
+ return bool(cfg.get("supports_image_input"))
+ # Lazy import: provider_capabilities -> llm_config -> services chain;
+ # importing at module load would create an init-order cycle through
+ # ``app.config``.
+ from app.services.provider_capabilities import derive_supports_image_input
+
+ cfg_litellm_params = cfg.get("litellm_params") or {}
+ base_model = (
+ cfg_litellm_params.get("base_model")
+ if isinstance(cfg_litellm_params, dict)
+ else None
+ )
+ return derive_supports_image_input(
+ provider=cfg.get("provider"),
+ model_name=cfg.get("model_name"),
+ base_model=base_model,
+ custom_provider=cfg.get("custom_provider"),
+ )
+
+
+def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
"""Return Auto-eligible global cfgs.
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
can't be picked as the thread's pin. Also excludes configs currently
in runtime cooldown (e.g. temporary 429 bursts).
+
+ When ``requires_image_input`` is True (image turn), additionally
+ filters out configs whose ``supports_image_input`` resolves to False
+ so a text-only deployment can't be pinned for an image request.
"""
candidates = [
cfg
@@ -177,6 +211,7 @@ def _global_candidates() -> list[dict]:
if _is_usable_global_config(cfg)
and not cfg.get("health_gated")
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
+ and (not requires_image_input or _cfg_supports_image_input(cfg))
]
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
@@ -237,11 +272,20 @@ async def resolve_or_get_pinned_llm_config_id(
selected_llm_config_id: int,
force_repin_free: bool = False,
exclude_config_ids: set[int] | None = None,
+ requires_image_input: bool = False,
) -> AutoPinResolution:
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
For non-auto selections, this function clears any existing pin and returns
the selected id as-is.
+
+ When ``requires_image_input`` is True (the current turn carries an
+ ``image_url`` block), the candidate pool is filtered to vision-capable
+ cfgs and any existing pin that can't accept image input is treated as
+ invalid (force re-pin). If no vision-capable cfg is available the
+ function raises ``ValueError`` so the streaming task surfaces the same
+ friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of
+ silently routing the image to a text-only deployment.
"""
thread = (
(
@@ -274,14 +318,24 @@ async def resolve_or_get_pinned_llm_config_id(
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
candidates = [
- c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
+ c
+ for c in _global_candidates(requires_image_input=requires_image_input)
+ if int(c.get("id", 0)) not in excluded_ids
]
if not candidates:
+ if requires_image_input:
+ # Distinguish the "no vision-capable cfg" case from generic
+ # "no usable cfg" so the streaming task can map this to the
+ # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
+ raise ValueError(
+ "No vision-capable global LLM configs are available for Auto mode"
+ )
raise ValueError("No usable global LLM configs are available for Auto mode")
candidate_by_id = {int(c["id"]): c for c in candidates}
# Reuse an existing valid pin without re-checking current quota (no silent
- # tier switch), unless the caller explicitly requests a forced repin to free.
+ # tier switch), unless the caller explicitly requests a forced repin to free
+ # *or* the turn requires image input but the pin can't handle it.
pinned_id = thread.pinned_llm_config_id
if (
not force_repin_free
@@ -311,6 +365,29 @@ async def resolve_or_get_pinned_llm_config_id(
from_existing_pin=True,
)
if pinned_id is not None:
+ # If the pin is *only* invalid because it can't handle the image
+ # turn (it's still a healthy, usable config in the broader pool),
+ # log that explicitly so operators can correlate the re-pin with
+ # the user's image attachment instead of suspecting a cooldown.
+ if requires_image_input:
+ try:
+ pinned_global = next(
+ c
+ for c in config.GLOBAL_LLM_CONFIGS
+ if int(c.get("id", 0)) == int(pinned_id)
+ )
+ except StopIteration:
+ pinned_global = None
+ if pinned_global is not None and not _cfg_supports_image_input(
+ pinned_global
+ ):
+ logger.info(
+ "auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
+ "previous_config_id=%s",
+ thread_id,
+ search_space_id,
+ pinned_id,
+ )
logger.info(
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
thread_id,
@@ -327,6 +404,10 @@ async def resolve_or_get_pinned_llm_config_id(
eligible = [c for c in candidates if _tier_of(c) != "premium"]
if not eligible:
+ if requires_image_input:
+ raise ValueError(
+ "Auto mode could not find a vision-capable LLM config for this user and quota state"
+ )
raise ValueError(
"Auto mode could not find an eligible LLM config for this user and quota state"
)
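
The image-turn filter itself is small enough to sketch. With ``requires_image_input=True`` any config whose flag resolves False drops out of the candidate pool, so a text-only deployment can never become the thread's pin (the configs below are illustrative and call the module's private helpers directly):

    vision_cfg = {
        "id": 1,
        "model_name": "example/vision-model",
        "supports_image_input": True,
    }
    text_only_cfg = {
        "id": 2,
        "model_name": "example/text-model",
        "supports_image_input": False,
    }
    assert _cfg_supports_image_input(vision_cfg) is True
    assert _cfg_supports_image_input(text_only_cfg) is False
    # _global_candidates(requires_image_input=True) applies the same check,
    # so only vision_cfg survives into the Auto-pin candidate pool.
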
diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py
index f5ca9818e..92ccd6a78 100644
--- a/surfsense_backend/app/services/billable_calls.py
+++ b/surfsense_backend/app/services/billable_calls.py
@@ -10,12 +10,14 @@ vision-LLM wrapper used during indexing) don't have to re-implement it.
KEY DESIGN POINTS (issue A, B):
-1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
- argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
- insert each run inside their own ``shielded_async_session()``. This
- guarantees that a quota commit/rollback can never accidentally flush or
- roll back rows the caller has staged in the request's main session
- (e.g. a freshly-created ``ImageGeneration`` row).
+1. **Session isolation.** ``billable_call`` takes no caller transaction.
+ All ``TokenQuotaService.premium_*`` calls and the audit-row insert run
+ inside their own session context. Route callers use
+ ``shielded_async_session()`` by default; Celery callers can provide a
+ worker-loop-safe session factory. This guarantees that quota
+ commit/rollback can never accidentally flush or roll back rows the caller
+ has staged in its main session (e.g. a freshly-created
+ ``ImageGeneration`` row).
2. **ContextVar safety.** The accumulator is scoped via
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
@@ -36,9 +38,10 @@ KEY DESIGN POINTS (issue A, B):
from __future__ import annotations
+import asyncio
import logging
-from collections.abc import AsyncIterator
-from contextlib import asynccontextmanager
+from collections.abc import AsyncIterator, Callable
+from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
from typing import Any
from uuid import UUID, uuid4
@@ -58,6 +61,12 @@ from app.services.token_tracking_service import (
logger = logging.getLogger(__name__)
+AUDIT_TIMEOUT_SECONDS = 10.0
+BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset(
+ {"video_presentation_generation", "podcast_generation"}
+)
+BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]]
+
class QuotaInsufficientError(Exception):
"""Raised when ``TokenQuotaService.premium_reserve`` denies a billable
@@ -88,6 +97,124 @@ class QuotaInsufficientError(Exception):
)
+class BillingSettlementError(Exception):
+ """Raised when a premium call completed but credit settlement failed."""
+
+ def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None:
+ self.usage_type = usage_type
+ self.user_id = user_id
+ super().__init__(
+ f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}"
+ )
+
+
+async def _rollback_safely(session: AsyncSession) -> None:
+ rollback = getattr(session, "rollback", None)
+ if rollback is not None:
+ with suppress(Exception):
+ await rollback()
+
+
+async def _record_audit_best_effort(
+ *,
+ session_factory: BillableSessionFactory,
+ usage_type: str,
+ search_space_id: int,
+ user_id: UUID,
+ prompt_tokens: int,
+ completion_tokens: int,
+ total_tokens: int,
+ cost_micros: int,
+ model_breakdown: dict[str, Any],
+ call_details: dict[str, Any] | None,
+ thread_id: int | None,
+ message_id: int | None,
+ audit_label: str,
+ timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
+) -> None:
+ """Persist a TokenUsage row without letting audit failure block callers.
+
+ Premium settlement is mandatory, but TokenUsage is an audit trail. If the
+ audit insert or commit hangs, user-facing artifacts such as videos and
+ podcasts must still be able to transition to READY after settlement.
+ """
+ audit_thread_id = (
+ None if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES else thread_id
+ )
+
+ async def _persist() -> None:
+ logger.info(
+ "[billable_call] audit start label=%s usage_type=%s user=%s thread=%s "
+ "total_tokens=%d cost_micros=%d",
+ audit_label,
+ usage_type,
+ user_id,
+ audit_thread_id,
+ total_tokens,
+ cost_micros,
+ )
+ async with session_factory() as audit_session:
+ try:
+ await record_token_usage(
+ audit_session,
+ usage_type=usage_type,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ cost_micros=cost_micros,
+ model_breakdown=model_breakdown,
+ call_details=call_details,
+ thread_id=audit_thread_id,
+ message_id=message_id,
+ )
+ logger.info(
+ "[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s",
+ audit_label,
+ usage_type,
+ user_id,
+ audit_thread_id,
+ )
+ await audit_session.commit()
+ logger.info(
+ "[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s",
+ audit_label,
+ usage_type,
+ user_id,
+ audit_thread_id,
+ )
+ except BaseException:
+ await _rollback_safely(audit_session)
+ raise
+
+ try:
+ await asyncio.wait_for(_persist(), timeout=timeout_seconds)
+ except TimeoutError:
+ logger.warning(
+ "[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s "
+ "timeout=%.1fs total_tokens=%d cost_micros=%d",
+ audit_label,
+ usage_type,
+ user_id,
+ audit_thread_id,
+ timeout_seconds,
+ total_tokens,
+ cost_micros,
+ )
+ except Exception:
+ logger.exception(
+ "[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s "
+ "total_tokens=%d cost_micros=%d",
+ audit_label,
+ usage_type,
+ user_id,
+ audit_thread_id,
+ total_tokens,
+ cost_micros,
+ )
+
+
@asynccontextmanager
async def billable_call(
*,
@@ -101,6 +228,8 @@ async def billable_call(
thread_id: int | None = None,
message_id: int | None = None,
call_details: dict[str, Any] | None = None,
+ billable_session_factory: BillableSessionFactory | None = None,
+ audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
) -> AsyncIterator[TurnTokenAccumulator]:
"""Wrap a single billable LLM/image call.
@@ -124,6 +253,13 @@ async def billable_call(
thread_id, message_id: Optional FK columns on ``TokenUsage``.
call_details: Optional per-call metadata (model name, parameters)
forwarded to ``record_token_usage``.
+ billable_session_factory: Optional async context factory used for
+ reserve/finalize/release/audit sessions. Defaults to
+ ``shielded_async_session`` for route callers; Celery callers pass
+ a worker-loop-safe session factory.
+ audit_timeout_seconds: Upper bound for TokenUsage audit persistence.
+ Audit failure is best-effort and does not undo successful
+ settlement.
Yields:
The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
@@ -134,6 +270,7 @@ async def billable_call(
QuotaInsufficientError: when premium and ``premium_reserve`` denies.
"""
is_premium = billing_tier == "premium"
+ session_factory = billable_session_factory or shielded_async_session
async with scoped_turn() as acc:
# ---------- Free path: just audit -------------------------------
@@ -143,30 +280,22 @@ async def billable_call(
finally:
# Always audit, even on exception, so we capture cost when
# provider returns successfully but the caller raises later.
- try:
- async with shielded_async_session() as audit_session:
- await record_token_usage(
- audit_session,
- usage_type=usage_type,
- search_space_id=search_space_id,
- user_id=user_id,
- prompt_tokens=acc.total_prompt_tokens,
- completion_tokens=acc.total_completion_tokens,
- total_tokens=acc.grand_total,
- cost_micros=acc.total_cost_micros,
- model_breakdown=acc.per_message_summary(),
- call_details=call_details,
- thread_id=thread_id,
- message_id=message_id,
- )
- await audit_session.commit()
- except Exception:
- logger.exception(
- "[billable_call] free-path audit insert failed for "
- "usage_type=%s user_id=%s",
- usage_type,
- user_id,
- )
+ await _record_audit_best_effort(
+ session_factory=session_factory,
+ usage_type=usage_type,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ prompt_tokens=acc.total_prompt_tokens,
+ completion_tokens=acc.total_completion_tokens,
+ total_tokens=acc.grand_total,
+ cost_micros=acc.total_cost_micros,
+ model_breakdown=acc.per_message_summary(),
+ call_details=call_details,
+ thread_id=thread_id,
+ message_id=message_id,
+ audit_label="free",
+ timeout_seconds=audit_timeout_seconds,
+ )
return
# ---------- Premium path: reserve → execute → finalize ----------
@@ -180,7 +309,7 @@ async def billable_call(
request_id = str(uuid4())
- async with shielded_async_session() as quota_session:
+ async with session_factory() as quota_session:
reserve_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=user_id,
@@ -222,7 +351,7 @@ async def billable_call(
# from a downstream call, asyncio cancellation, etc.). We use
# BaseException so cancellation also releases.
try:
- async with shielded_async_session() as quota_session:
+ async with session_factory() as quota_session:
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=user_id,
@@ -241,7 +370,16 @@ async def billable_call(
# ---------- Success: finalize + audit ----------------------------
actual_micros = acc.total_cost_micros
try:
- async with shielded_async_session() as quota_session:
+ logger.info(
+ "[billable_call] finalize start user=%s usage_type=%s actual=%d "
+ "reserved=%d thread=%s",
+ user_id,
+ usage_type,
+ actual_micros,
+ reserve_micros,
+ thread_id,
+ )
+ async with session_factory() as quota_session:
final_result = await TokenQuotaService.premium_finalize(
db_session=quota_session,
user_id=user_id,
@@ -260,7 +398,7 @@ async def billable_call(
final_result.limit,
final_result.remaining,
)
- except Exception:
+ except Exception as finalize_exc:
# Last-ditch: if finalize itself fails, we must at least release
# so the reservation doesn't leak.
logger.exception(
@@ -269,7 +407,7 @@ async def billable_call(
user_id,
)
try:
- async with shielded_async_session() as quota_session:
+ async with session_factory() as quota_session:
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=user_id,
@@ -281,31 +419,28 @@ async def billable_call(
"for user=%s",
user_id,
)
+ raise BillingSettlementError(
+ usage_type=usage_type,
+ user_id=user_id,
+ cause=finalize_exc,
+ ) from finalize_exc
- try:
- async with shielded_async_session() as audit_session:
- await record_token_usage(
- audit_session,
- usage_type=usage_type,
- search_space_id=search_space_id,
- user_id=user_id,
- prompt_tokens=acc.total_prompt_tokens,
- completion_tokens=acc.total_completion_tokens,
- total_tokens=acc.grand_total,
- cost_micros=actual_micros,
- model_breakdown=acc.per_message_summary(),
- call_details=call_details,
- thread_id=thread_id,
- message_id=message_id,
- )
- await audit_session.commit()
- except Exception:
- logger.exception(
- "[billable_call] premium-path audit insert failed for "
- "usage_type=%s user_id=%s (debit was applied)",
- usage_type,
- user_id,
- )
+ await _record_audit_best_effort(
+ session_factory=session_factory,
+ usage_type=usage_type,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ prompt_tokens=acc.total_prompt_tokens,
+ completion_tokens=acc.total_completion_tokens,
+ total_tokens=acc.grand_total,
+ cost_micros=actual_micros,
+ model_breakdown=acc.per_message_summary(),
+ call_details=call_details,
+ thread_id=thread_id,
+ message_id=message_id,
+ audit_label="premium",
+ timeout_seconds=audit_timeout_seconds,
+ )
async def _resolve_agent_billing_for_search_space(
@@ -419,6 +554,7 @@ async def _resolve_agent_billing_for_search_space(
__all__ = [
+ "BillingSettlementError",
"QuotaInsufficientError",
"_resolve_agent_billing_for_search_space",
"billable_call",
diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py
index f45a6ab63..b4de2a0bf 100644
--- a/surfsense_backend/app/services/image_gen_router_service.py
+++ b/surfsense_backend/app/services/image_gen_router_service.py
@@ -20,6 +20,8 @@ from typing import Any
from litellm import Router
from litellm.utils import ImageResponse
+from app.services.provider_api_base import resolve_api_base
+
logger = logging.getLogger(__name__)
# Special ID for Auto mode - uses router for load balancing
@@ -152,12 +154,12 @@ class ImageGenRouterService:
return None
# Build model string
+ provider = config.get("provider", "").upper()
if config.get("custom_provider"):
- model_string = f"{config['custom_provider']}/{config['model_name']}"
+ provider_prefix = config["custom_provider"]
else:
- provider = config.get("provider", "").upper()
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
- model_string = f"{provider_prefix}/{config['model_name']}"
+ model_string = f"{provider_prefix}/{config['model_name']}"
# Build litellm params
litellm_params: dict[str, Any] = {
@@ -165,9 +167,16 @@ class ImageGenRouterService:
"api_key": config.get("api_key"),
}
- # Add optional api_base
- if config.get("api_base"):
- litellm_params["api_base"] = config["api_base"]
+ # Resolve ``api_base`` so deployments don't silently inherit
+ # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
+ # the wrong provider (see ``provider_api_base`` docstring).
+ api_base = resolve_api_base(
+ provider=provider,
+ provider_prefix=provider_prefix,
+ config_api_base=config.get("api_base"),
+ )
+ if api_base:
+ litellm_params["api_base"] = api_base
# Add api_version (required for Azure)
if config.get("api_version"):
diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py
index 1e9d235c8..d220aa346 100644
--- a/surfsense_backend/app/services/llm_router_service.py
+++ b/surfsense_backend/app/services/llm_router_service.py
@@ -140,8 +140,6 @@ PROVIDER_MAP = {
# 404-ing against an inherited Azure endpoint). Re-exported here for
# backward compatibility with any external import.
from app.services.provider_api_base import ( # noqa: E402
- PROVIDER_DEFAULT_API_BASE,
- PROVIDER_KEY_DEFAULT_API_BASE,
resolve_api_base,
)
diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py
index 72c10035d..ade202c72 100644
--- a/surfsense_backend/app/services/llm_service.py
+++ b/surfsense_backend/app/services/llm_service.py
@@ -16,6 +16,7 @@ from app.services.llm_router_service import (
get_auto_mode_llm,
is_auto_mode,
)
+from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import token_tracker
# Configure litellm to automatically drop unsupported parameters
@@ -556,22 +557,26 @@ async def get_vision_llm(
return None
if global_cfg.get("custom_provider"):
- model_string = (
- f"{global_cfg['custom_provider']}/{global_cfg['model_name']}"
- )
+ provider_prefix = global_cfg["custom_provider"]
+ model_string = f"{provider_prefix}/{global_cfg['model_name']}"
else:
- prefix = VISION_PROVIDER_MAP.get(
+ provider_prefix = VISION_PROVIDER_MAP.get(
global_cfg["provider"].upper(),
global_cfg["provider"].lower(),
)
- model_string = f"{prefix}/{global_cfg['model_name']}"
+ model_string = f"{provider_prefix}/{global_cfg['model_name']}"
litellm_kwargs = {
"model": model_string,
"api_key": global_cfg["api_key"],
}
- if global_cfg.get("api_base"):
- litellm_kwargs["api_base"] = global_cfg["api_base"]
+ api_base = resolve_api_base(
+ provider=global_cfg.get("provider"),
+ provider_prefix=provider_prefix,
+ config_api_base=global_cfg.get("api_base"),
+ )
+ if api_base:
+ litellm_kwargs["api_base"] = api_base
if global_cfg.get("litellm_params"):
litellm_kwargs.update(global_cfg["litellm_params"])
@@ -606,20 +611,26 @@ async def get_vision_llm(
return None
if vision_cfg.custom_provider:
- model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}"
+ provider_prefix = vision_cfg.custom_provider
+ model_string = f"{provider_prefix}/{vision_cfg.model_name}"
else:
- prefix = VISION_PROVIDER_MAP.get(
+ provider_prefix = VISION_PROVIDER_MAP.get(
vision_cfg.provider.value.upper(),
vision_cfg.provider.value.lower(),
)
- model_string = f"{prefix}/{vision_cfg.model_name}"
+ model_string = f"{provider_prefix}/{vision_cfg.model_name}"
litellm_kwargs = {
"model": model_string,
"api_key": vision_cfg.api_key,
}
- if vision_cfg.api_base:
- litellm_kwargs["api_base"] = vision_cfg.api_base
+ api_base = resolve_api_base(
+ provider=vision_cfg.provider.value,
+ provider_prefix=provider_prefix,
+ config_api_base=vision_cfg.api_base,
+ )
+ if api_base:
+ litellm_kwargs["api_base"] = api_base
if vision_cfg.litellm_params:
litellm_kwargs.update(vision_cfg.litellm_params)
diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py
index 0d030f04f..6454e2d58 100644
--- a/surfsense_backend/app/services/openrouter_integration_service.py
+++ b/surfsense_backend/app/services/openrouter_integration_service.py
@@ -122,6 +122,24 @@ def _is_vision_input_model(model: dict) -> bool:
return "image" in input_mods and "text" in output_mods
+def _supports_image_input(model: dict) -> bool:
+ """Return True if the model accepts ``image`` in its input modalities.
+
+ Differs from :func:`_is_vision_input_model` in that it does NOT
+ require text output — chat-tab models always emit text already (the
+ chat catalog filters by ``_is_text_output_model``), so the only
+ extra capability we need to track per chat config is whether the
+ model can ingest user-attached images. The chat selector and the
+ streaming task both key off this flag to prevent hitting an
+ OpenRouter 404 ``"No endpoints found that support image input"``
+ when the user uploads an image and selects a text-only model
+ (DeepSeek V3, Llama 3.x base, etc.).
+ """
+ arch = model.get("architecture", {}) or {}
+ input_mods = arch.get("input_modalities", []) or []
+ return "image" in input_mods
+
+
def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or []
@@ -321,6 +339,13 @@ def _generate_configs(
# account-wide quota, so per-deployment routing can't spread load
# there — it just drains the shared bucket faster.
"router_pool_eligible": tier == "premium",
+ # Capability flag derived from ``architecture.input_modalities``.
+ # Read by the new-chat selector to dim image-incompatible models
+ # when the user has pending image attachments, and by
+ # ``stream_new_chat`` as a fail-fast safety net before the
+ # OpenRouter request would otherwise 404 with
+ # ``"No endpoints found that support image input"``.
+ "supports_image_input": _supports_image_input(model),
_OPENROUTER_DYNAMIC_MARKER: True,
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
# to the static score and gets re-blended with health on the next
@@ -398,7 +423,12 @@ def _generate_image_gen_configs(
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
- "api_base": "",
+ # Pin to OpenRouter's public base URL so a downstream call site
+ # that forgets ``resolve_api_base`` still doesn't inherit
+ # ``AZURE_OPENAI_ENDPOINT`` and 404 on
+ # ``image_generation/transformation`` (defense-in-depth, see
+ # ``provider_api_base`` docstring).
+ "api_base": "https://openrouter.ai/api/v1",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
"litellm_params": dict(litellm_params),
@@ -477,7 +507,11 @@ def _generate_vision_llm_configs(
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
- "api_base": "",
+ # Pin to OpenRouter's public base URL so a downstream call site
+ # that forgets ``resolve_api_base`` still doesn't inherit
+ # ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see
+ # ``provider_api_base`` docstring).
+ "api_base": "https://openrouter.ai/api/v1",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
"tpm": free_tpm if tier == "free" else tpm,
diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py
index 979d7d3a1..dca1f9462 100644
--- a/surfsense_backend/app/services/provider_api_base.py
+++ b/surfsense_backend/app/services/provider_api_base.py
@@ -17,7 +17,6 @@ source of truth without an inter-service circular import.
from __future__ import annotations
-
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",
diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py
new file mode 100644
index 000000000..e9a1c33e1
--- /dev/null
+++ b/surfsense_backend/app/services/provider_capabilities.py
@@ -0,0 +1,280 @@
+"""Capability resolution shared by chat / image / vision call sites.
+
+Why this exists
+---------------
+The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a
+single, authoritative answer to one question: *can this chat config accept
+``image_url`` content blocks?* Without it, the new-chat selector can't badge
+incompatible models and the streaming task can't fail fast with a friendly
+error before sending an image to a text-only provider.
+
+Two functions, two intents:
+
+- :func:`derive_supports_image_input` — best-effort *True* for catalog and
+ UI surfacing. Default-allow: an unknown / unmapped model is treated as
+ capable so we never lock the user out of a freshly added or
+ third-party-hosted vision model.
+
+- :func:`is_known_text_only_chat_model` — strict opt-out for the streaming
+ task's safety net. Returns True only when LiteLLM's model map *explicitly*
+ sets ``supports_vision=False`` (or its bare-name variant does). Anything
+ else — missing key, lookup exception, ``supports_vision=True`` — returns
+ False so the request flows through to the provider.
+
+Implementation rule: only public LiteLLM symbols
+------------------------------------------------
+``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the
+typed module surface (see ``litellm.__init__`` lazy stubs) and are stable
+across releases. The private ``_is_explicitly_disabled_factory`` and
+``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade
+can't silently break us.
+
+Why the previous round's strict YAML opt-in flag failed
+-------------------------------------------------------
+``supports_image_input: false`` was the YAML loader's setdefault. Operators
+maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI
+YAML chat model — including vision-capable GPT-5.x and GPT-4o — resolved to
+False and the streaming gate rejected every image turn. Sourcing capability
+from LiteLLM's authoritative model map (which already says
+``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Iterable
+
+import litellm
+
+logger = logging.getLogger(__name__)
+
+
+# Provider-name → LiteLLM model-prefix map.
+#
+# Owned here because ``app.services.provider_capabilities`` is the
+# only edge that's safe to call from ``app.config``'s YAML loader at
+# class-body init time. ``app.agents.new_chat.llm_config`` re-exports
+# this constant under the historical ``PROVIDER_MAP`` name; placing the
+# map there directly would re-introduce the
+# ``app.config -> ... -> app.agents.new_chat.tools.generate_image ->
+# app.config`` cycle that prompted the move.
+_PROVIDER_PREFIX_MAP: dict[str, str] = {
+ "OPENAI": "openai",
+ "ANTHROPIC": "anthropic",
+ "GROQ": "groq",
+ "COHERE": "cohere",
+ "GOOGLE": "gemini",
+ "OLLAMA": "ollama_chat",
+ "MISTRAL": "mistral",
+ "AZURE_OPENAI": "azure",
+ "OPENROUTER": "openrouter",
+ "XAI": "xai",
+ "BEDROCK": "bedrock",
+ "VERTEX_AI": "vertex_ai",
+ "TOGETHER_AI": "together_ai",
+ "FIREWORKS_AI": "fireworks_ai",
+ "DEEPSEEK": "openai",
+ "ALIBABA_QWEN": "openai",
+ "MOONSHOT": "openai",
+ "ZHIPU": "openai",
+ "GITHUB_MODELS": "github",
+ "REPLICATE": "replicate",
+ "PERPLEXITY": "perplexity",
+ "ANYSCALE": "anyscale",
+ "DEEPINFRA": "deepinfra",
+ "CEREBRAS": "cerebras",
+ "SAMBANOVA": "sambanova",
+ "AI21": "ai21",
+ "CLOUDFLARE": "cloudflare",
+ "DATABRICKS": "databricks",
+ "COMETAPI": "cometapi",
+ "HUGGINGFACE": "huggingface",
+ "MINIMAX": "openai",
+ "CUSTOM": "custom",
+}
+
+
+def _candidate_model_strings(
+ *,
+ provider: str | None,
+ model_name: str | None,
+ base_model: str | None,
+ custom_provider: str | None,
+) -> list[tuple[str, str | None]]:
+ """Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates.
+
+ LiteLLM's capability lookup is keyed by ``model`` + (optional)
+ ``custom_llm_provider``. Different config sources give us different
+ levels of detail, so we try the most-specific keys first and fall back
+ to bare model names so unannotated entries (e.g. an Azure deployment
+ pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the
+ map. Order matters — the first lookup that returns a definitive answer
+ wins for both helpers.
+ """
+ candidates: list[tuple[str, str | None]] = []
+ seen: set[tuple[str, str | None]] = set()
+
+ def _add(model: str | None, llm_provider: str | None) -> None:
+ if not model:
+ return
+ key = (model, llm_provider)
+ if key in seen:
+ return
+ seen.add(key)
+ candidates.append(key)
+
+ provider_prefix: str | None = None
+ if provider:
+ provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower())
+ if custom_provider:
+ # ``custom_provider`` overrides everything for CUSTOM/proxy setups.
+ provider_prefix = custom_provider
+
+ primary_model = base_model or model_name
+ bare_model = model_name
+
+ # Most-specific first: provider-prefixed identifier with explicit
+ # custom_llm_provider so LiteLLM won't have to guess the provider via
+ # ``get_llm_provider``.
+ if primary_model and provider_prefix:
+ # e.g. "azure/gpt-5.4" + custom_llm_provider="azure"
+ if "/" in primary_model:
+ _add(primary_model, provider_prefix)
+ else:
+ _add(f"{provider_prefix}/{primary_model}", provider_prefix)
+
+ # Bare base_model (or model_name) with provider hint — handles entries
+ # the upstream map keys without a provider prefix (most ``gpt-*`` and
+ # ``claude-*`` entries do this).
+ if primary_model:
+ _add(primary_model, provider_prefix)
+
+ # Fallback to model_name when base_model differs (e.g. an Azure
+ # deployment whose model_name is the deployment id but base_model is the
+ # canonical OpenAI sku).
+ if bare_model and bare_model != primary_model:
+ if provider_prefix and "/" not in bare_model:
+ _add(f"{provider_prefix}/{bare_model}", provider_prefix)
+ _add(bare_model, provider_prefix)
+ _add(bare_model, None)
+
+ return candidates
+
+
+def derive_supports_image_input(
+ *,
+ provider: str | None = None,
+ model_name: str | None = None,
+ base_model: str | None = None,
+ custom_provider: str | None = None,
+ openrouter_input_modalities: Iterable[str] | None = None,
+) -> bool:
+ """Best-effort capability flag for the new-chat selector and catalog.
+
+ Resolution order (first definitive answer wins):
+
+ 1. ``openrouter_input_modalities`` (when provided). OpenRouter exposes
+ ``architecture.input_modalities`` per model and that's the authoritative
+ source for OR dynamic configs; an explicitly empty list means "no image".
+ 2. ``litellm.supports_vision`` against each candidate identifier from
+ :func:`_candidate_model_strings`. Returns True as soon as any
+ candidate confirms vision support.
+ 3. Default ``True`` — the conservative-allow stance. An unknown /
+ newly-added / third-party-hosted model is *not* pre-judged. The
+ streaming safety net (:func:`is_known_text_only_chat_model`) is the
+ only place a False ever blocks; everywhere else, a False here would
+ just hide a usable model from the user.
+
+ Returns:
+ True if the model can plausibly accept image input, False only when
+ OpenRouter explicitly says it can't.
+ """
+ if openrouter_input_modalities is not None:
+ modalities = list(openrouter_input_modalities)
+ if modalities:
+ return "image" in modalities
+ # Empty list explicitly published by OR — treat as "no image".
+ return False
+
+ for model_string, custom_llm_provider in _candidate_model_strings(
+ provider=provider,
+ model_name=model_name,
+ base_model=base_model,
+ custom_provider=custom_provider,
+ ):
+ try:
+ if litellm.supports_vision(
+ model=model_string, custom_llm_provider=custom_llm_provider
+ ):
+ return True
+ except Exception as exc:
+ logger.debug(
+ "litellm.supports_vision raised for model=%s provider=%s: %s",
+ model_string,
+ custom_llm_provider,
+ exc,
+ )
+ continue
+
+ # Default-allow. ``is_known_text_only_chat_model`` is the strict gate.
+ return True
+
+
+def is_known_text_only_chat_model(
+ *,
+ provider: str | None = None,
+ model_name: str | None = None,
+ base_model: str | None = None,
+ custom_provider: str | None = None,
+) -> bool:
+ """Strict opt-out probe for the streaming-task safety net.
+
+ Returns True only when LiteLLM's model map *explicitly* sets
+ ``supports_vision=False`` for at least one candidate identifier. Missing
+ key, lookup exception, or ``supports_vision=True`` all return False so
+ the streaming task lets the request through. This is the inverse-default
+ of :func:`derive_supports_image_input`.
+
+ Why two functions
+ -----------------
+ The selector wants "show me everything that's plausibly capable" —
+ default-allow. The safety net wants "block only when I'm certain it
+ can't" — default-pass. Mixing the two intents in a single function
+ leads to the regression we're fixing here.
+ """
+ for model_string, custom_llm_provider in _candidate_model_strings(
+ provider=provider,
+ model_name=model_name,
+ base_model=base_model,
+ custom_provider=custom_provider,
+ ):
+ try:
+ info = litellm.get_model_info(
+ model=model_string, custom_llm_provider=custom_llm_provider
+ )
+ except Exception as exc:
+ logger.debug(
+ "litellm.get_model_info raised for model=%s provider=%s: %s",
+ model_string,
+ custom_llm_provider,
+ exc,
+ )
+ continue
+
+ # ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision``
+ # may be missing, None, True, or False. We only fire on explicit
+ # False — None / missing / True all mean "don't block".
+ try:
+ value = info.get("supports_vision") # type: ignore[union-attr]
+ except AttributeError:
+ value = None
+ if value is False:
+ return True
+
+ return False
+
+
+__all__ = [
+ "derive_supports_image_input",
+ "is_known_text_only_chat_model",
+]
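
The asymmetry between the two helpers is easiest to see side by side. A small sketch against the module as added above; the model name is a placeholder that LiteLLM's map does not know, so both lookups fall through to their defaults:

    from app.services.provider_capabilities import (
        derive_supports_image_input,
        is_known_text_only_chat_model,
    )

    # Catalog / selector question: "is this plausibly image-capable?"
    # Unknown model -> default-allow -> True.
    derive_supports_image_input(
        provider="CUSTOM",
        model_name="totally-unmapped-model",
        custom_provider="my-proxy",   # hypothetical proxy prefix
    )

    # Streaming safety-net question: "am I certain it is text-only?"
    # Same unknown model -> default-pass -> False, so nothing is blocked.
    is_known_text_only_chat_model(
        provider="CUSTOM",
        model_name="totally-unmapped-model",
        custom_provider="my-proxy",
    )
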
diff --git a/surfsense_backend/app/tasks/celery_tasks/__init__.py b/surfsense_backend/app/tasks/celery_tasks/__init__.py
index 5b1f2cd13..b23359f36 100644
--- a/surfsense_backend/app/tasks/celery_tasks/__init__.py
+++ b/surfsense_backend/app/tasks/celery_tasks/__init__.py
@@ -1,10 +1,25 @@
-"""Celery tasks package."""
+"""Celery tasks package.
+
+Also hosts the small helpers every async celery task should use to
+spin up its event loop. See :func:`run_async_celery_task` for the
+canonical pattern.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import logging
+from collections.abc import Awaitable, Callable
+from typing import TypeVar
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.config import config
+logger = logging.getLogger(__name__)
+
_celery_engine = None
_celery_session_maker = None
@@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker:
_celery_engine, expire_on_commit=False
)
return _celery_session_maker
+
+
+def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None:
+ """Drop the shared ``app.db.engine`` connection pool synchronously.
+
+ The shared engine (used by ``shielded_async_session`` and most
+ routes / services) is a module-level singleton with a real pool.
+ Each celery task creates a fresh ``asyncio`` event loop; asyncpg
+ connections cache a reference to whichever loop opened them. When
+ a subsequent task's loop pulls a stale connection from the pool,
+ SQLAlchemy's ``pool_pre_ping`` checkout crashes with::
+
+ AttributeError: 'NoneType' object has no attribute 'send'
+ File ".../asyncio/proactor_events.py", line 402, in _loop_writing
+ self._write_fut = self._loop._proactor.send(self._sock, data)
+
+ or hangs forever inside the asyncpg ``Connection._cancel`` cleanup
+ coroutine that can never run because its loop is gone.
+
+ Disposing the engine forces the pool to drop every cached
+ connection so the next checkout opens a fresh one on the current
+ loop. Safe to call from a task's finally block; failure is logged
+ but never propagated.
+ """
+ try:
+ from app.db import engine as shared_engine
+
+ loop.run_until_complete(shared_engine.dispose())
+ except Exception:
+ logger.warning("Shared DB engine dispose() failed", exc_info=True)
+
+
+T = TypeVar("T")
+
+
+def run_async_celery_task(coro_factory: Callable[[], Awaitable[T]]) -> T:
+ """Run an async coroutine inside a fresh event loop with proper
+ DB-engine cleanup.
+
+ This is the canonical entry point for every async celery task.
+ It performs three responsibilities that were previously copy-pasted
+ (incorrectly) across each task module:
+
+ 1. Create a fresh ``asyncio`` loop and install it on the current
+ thread (celery's ``--pool=solo`` runs every task on the main
+ thread, but other pool types don't).
+ 2. Dispose the shared ``app.db.engine`` BEFORE the task runs so
+ any stale connections left over from a previous task's loop
+ are dropped — defends against tasks that crashed without
+ cleaning up.
+ 3. Dispose the shared engine AFTER the task runs so the
+ connections we opened on this loop are released before the
+ loop closes (avoids ``coroutine 'Connection._cancel' was
+ never awaited`` warnings and the next-task hang).
+
+ Use as::
+
+ @celery_app.task(name="my_task", bind=True)
+ def my_task(self, *args):
+ return run_async_celery_task(lambda: _my_task_impl(*args))
+ """
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ # Defense-in-depth: prior task may have crashed before
+ # disposing. Idempotent — no-op if pool is already empty.
+ _dispose_shared_db_engine(loop)
+ return loop.run_until_complete(coro_factory())
+ finally:
+ # Drop any connections this task opened so they don't leak
+ # into the next task's loop.
+ _dispose_shared_db_engine(loop)
+ with contextlib.suppress(Exception):
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ with contextlib.suppress(Exception):
+ asyncio.set_event_loop(None)
+ loop.close()
+
+
+__all__ = [
+ "get_celery_session_maker",
+ "run_async_celery_task",
+]
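
Putting the two exports together, a converted task ends up with the shape below. This is a sketch only (the task and impl names are placeholders) matching the pattern the connector-task hunks that follow apply file by file:

    from app.celery_app import celery_app
    from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task

    @celery_app.task(name="example_index_task", bind=True)
    def example_index_task(self, document_id: int):
        # Sync entrypoint: fresh loop plus shared-engine disposal handled once.
        return run_async_celery_task(lambda: _example_index(document_id))

    async def _example_index(document_id: int) -> None:
        # Celery-local sessionmaker (NullPool engine), never the shared pool.
        session_maker = get_celery_session_maker()
        async with session_maker() as session:
            ...  # do the indexing work against `session`
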
diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py
index fe1ac19d3..08d96cfa0 100644
--- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py
@@ -4,7 +4,7 @@ import logging
import traceback
from app.celery_app import celery_app
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -49,22 +49,15 @@ def index_notion_pages_task(
end_date: str,
):
"""Celery task to index Notion pages."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- loop.run_until_complete(
- _index_notion_pages(
+ return run_async_celery_task(
+ lambda: _index_notion_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_notion_pages", connector_id)
raise
- finally:
- loop.close()
async def _index_notion_pages(
@@ -95,19 +88,11 @@ def index_github_repos_task(
end_date: str,
):
"""Celery task to index GitHub repositories."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_github_repos(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_github_repos(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_github_repos(
@@ -138,19 +123,11 @@ def index_confluence_pages_task(
end_date: str,
):
"""Celery task to index Confluence pages."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_confluence_pages(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_confluence_pages(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_confluence_pages(
@@ -181,22 +158,15 @@ def index_google_calendar_events_task(
end_date: str,
):
"""Celery task to index Google Calendar events."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- loop.run_until_complete(
- _index_google_calendar_events(
+ return run_async_celery_task(
+ lambda: _index_google_calendar_events(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
raise
- finally:
- loop.close()
async def _index_google_calendar_events(
@@ -227,19 +197,11 @@ def index_google_gmail_messages_task(
end_date: str,
):
"""Celery task to index Google Gmail messages."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_google_gmail_messages(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_google_gmail_messages(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_google_gmail_messages(
@@ -269,22 +231,14 @@ def index_google_drive_files_task(
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
):
"""Celery task to index Google Drive folders and files."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_google_drive_files(
- connector_id,
- search_space_id,
- user_id,
- items_dict,
- )
+ return run_async_celery_task(
+ lambda: _index_google_drive_files(
+ connector_id,
+ search_space_id,
+ user_id,
+ items_dict,
)
- finally:
- loop.close()
+ )
async def _index_google_drive_files(
@@ -317,22 +271,14 @@ def index_onedrive_files_task(
items_dict: dict,
):
"""Celery task to index OneDrive folders and files."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_onedrive_files(
- connector_id,
- search_space_id,
- user_id,
- items_dict,
- )
+ return run_async_celery_task(
+ lambda: _index_onedrive_files(
+ connector_id,
+ search_space_id,
+ user_id,
+ items_dict,
)
- finally:
- loop.close()
+ )
async def _index_onedrive_files(
@@ -365,22 +311,14 @@ def index_dropbox_files_task(
items_dict: dict,
):
"""Celery task to index Dropbox folders and files."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_dropbox_files(
- connector_id,
- search_space_id,
- user_id,
- items_dict,
- )
+ return run_async_celery_task(
+ lambda: _index_dropbox_files(
+ connector_id,
+ search_space_id,
+ user_id,
+ items_dict,
)
- finally:
- loop.close()
+ )
async def _index_dropbox_files(
@@ -414,19 +352,11 @@ def index_elasticsearch_documents_task(
end_date: str,
):
"""Celery task to index Elasticsearch documents."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_elasticsearch_documents(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_elasticsearch_documents(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_elasticsearch_documents(
@@ -457,22 +387,15 @@ def index_crawled_urls_task(
end_date: str,
):
"""Celery task to index Web page Urls."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- loop.run_until_complete(
- _index_crawled_urls(
+ return run_async_celery_task(
+ lambda: _index_crawled_urls(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
raise
- finally:
- loop.close()
async def _index_crawled_urls(
@@ -503,19 +426,11 @@ def index_bookstack_pages_task(
end_date: str,
):
"""Celery task to index BookStack pages."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_bookstack_pages(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_bookstack_pages(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_bookstack_pages(
@@ -546,19 +461,11 @@ def index_composio_connector_task(
end_date: str | None,
):
"""Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio)."""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_composio_connector(
- connector_id, search_space_id, user_id, start_date, end_date
- )
+ return run_async_celery_task(
+ lambda: _index_composio_connector(
+ connector_id, search_space_id, user_id, start_date, end_date
)
- finally:
- loop.close()
+ )
async def _index_composio_connector(
diff --git a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py
index c2dbe7700..5d6bde6c1 100644
--- a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py
@@ -11,7 +11,7 @@ from app.db import Document
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str):
document_id: ID of document to reindex
user_id: ID of user who edited the document
"""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(_reindex_document(document_id, user_id))
- finally:
- loop.close()
+ return run_async_celery_task(lambda: _reindex_document(document_id, user_id))
async def _reindex_document(document_id: int, user_id: str):
diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py
index 9d12f91f6..c78e376bd 100644
--- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py
@@ -11,7 +11,7 @@ from app.celery_app import celery_app
from app.config import config
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from app.tasks.connector_indexers.local_folder_indexer import (
index_local_folder,
index_uploaded_files,
@@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int):
)
def delete_document_task(self, document_id: int):
"""Celery task to delete a document and its chunks in batches."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(_delete_document_background(document_id))
- finally:
- loop.close()
+ return run_async_celery_task(lambda: _delete_document_background(document_id))
async def _delete_document_background(document_id: int) -> None:
@@ -153,14 +148,9 @@ def delete_folder_documents_task(
folder_subtree_ids: list[int] | None = None,
):
"""Celery task to delete documents first, then the folder rows."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- _delete_folder_documents(document_ids, folder_subtree_ids)
- )
- finally:
- loop.close()
+ return run_async_celery_task(
+ lambda: _delete_folder_documents(document_ids, folder_subtree_ids)
+ )
async def _delete_folder_documents(
@@ -209,12 +199,9 @@ async def _delete_folder_documents(
)
def delete_search_space_task(self, search_space_id: int):
"""Celery task to delete a search space and heavy child rows in batches."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(_delete_search_space_background(search_space_id))
- finally:
- loop.close()
+ return run_async_celery_task(
+ lambda: _delete_search_space_background(search_space_id)
+ )
async def _delete_search_space_background(search_space_id: int) -> None:
@@ -269,18 +256,11 @@ def process_extension_document_task(
search_space_id: ID of the search space
user_id: ID of the user
"""
- # Create a new event loop for this task
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _process_extension_document(
- individual_document_dict, search_space_id, user_id
- )
+ return run_async_celery_task(
+ lambda: _process_extension_document(
+ individual_document_dict, search_space_id, user_id
)
- finally:
- loop.close()
+ )
async def _process_extension_document(
@@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st
search_space_id: ID of the search space
user_id: ID of the user
"""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id))
- finally:
- loop.close()
+ return run_async_celery_task(
+ lambda: _process_youtube_video(url, search_space_id, user_id)
+ )
async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
@@ -573,12 +549,9 @@ def process_file_upload_task(
except Exception as e:
logger.warning(f"[process_file_upload] Could not get file size: {e}")
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- loop.run_until_complete(
- _process_file_upload(file_path, filename, search_space_id, user_id)
+ run_async_celery_task(
+ lambda: _process_file_upload(file_path, filename, search_space_id, user_id)
)
logger.info(
f"[process_file_upload] Task completed successfully for: {filename}"
@@ -589,8 +562,6 @@ def process_file_upload_task(
f"Traceback:\n{traceback.format_exc()}"
)
raise
- finally:
- loop.close()
async def _process_file_upload(
@@ -811,25 +782,17 @@ def process_file_upload_with_document_task(
"File may have been removed before syncing could start."
)
# Mark document as failed since file is missing
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- _mark_document_failed(
- document_id,
- "File not found. Please re-upload the file.",
- )
+ run_async_celery_task(
+ lambda: _mark_document_failed(
+ document_id,
+ "File not found. Please re-upload the file.",
)
- finally:
- loop.close()
+ )
return
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- loop.run_until_complete(
- _process_file_with_document(
+ run_async_celery_task(
+ lambda: _process_file_with_document(
document_id,
temp_path,
filename,
@@ -849,8 +812,6 @@ def process_file_upload_with_document_task(
f"Traceback:\n{traceback.format_exc()}"
)
raise
- finally:
- loop.close()
async def _mark_document_failed(document_id: int, reason: str):
@@ -1119,22 +1080,16 @@ def process_circleback_meeting_task(
search_space_id: ID of the search space
connector_id: ID of the Circleback connector (for deletion support)
"""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _process_circleback_meeting(
- meeting_id,
- meeting_name,
- markdown_content,
- metadata,
- search_space_id,
- connector_id,
- )
+ return run_async_celery_task(
+ lambda: _process_circleback_meeting(
+ meeting_id,
+ meeting_name,
+ markdown_content,
+ metadata,
+ search_space_id,
+ connector_id,
)
- finally:
- loop.close()
+ )
async def _process_circleback_meeting(
@@ -1291,25 +1246,19 @@ def index_local_folder_task(
target_file_paths: list[str] | None = None,
):
"""Celery task to index a local folder. Config is passed directly — no connector row."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(
- _index_local_folder_async(
- search_space_id=search_space_id,
- user_id=user_id,
- folder_path=folder_path,
- folder_name=folder_name,
- exclude_patterns=exclude_patterns,
- file_extensions=file_extensions,
- root_folder_id=root_folder_id,
- enable_summary=enable_summary,
- target_file_paths=target_file_paths,
- )
+ return run_async_celery_task(
+ lambda: _index_local_folder_async(
+ search_space_id=search_space_id,
+ user_id=user_id,
+ folder_path=folder_path,
+ folder_name=folder_name,
+ exclude_patterns=exclude_patterns,
+ file_extensions=file_extensions,
+ root_folder_id=root_folder_id,
+ enable_summary=enable_summary,
+ target_file_paths=target_file_paths,
)
- finally:
- loop.close()
+ )
async def _index_local_folder_async(
@@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task(
processing_mode: str = "basic",
):
"""Celery task to index files uploaded from the desktop app."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- _index_uploaded_folder_files_async(
- search_space_id=search_space_id,
- user_id=user_id,
- folder_name=folder_name,
- root_folder_id=root_folder_id,
- enable_summary=enable_summary,
- file_mappings=file_mappings,
- use_vision_llm=use_vision_llm,
- processing_mode=processing_mode,
- )
+ return run_async_celery_task(
+ lambda: _index_uploaded_folder_files_async(
+ search_space_id=search_space_id,
+ user_id=user_id,
+ folder_name=folder_name,
+ root_folder_id=root_folder_id,
+ enable_summary=enable_summary,
+ file_mappings=file_mappings,
+ use_vision_llm=use_vision_llm,
+ processing_mode=processing_mode,
)
- finally:
- loop.close()
+ )
async def _index_uploaded_folder_files_async(
@@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str:
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
"""Full AI sort for all documents in a search space."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
- finally:
- loop.close()
+ return run_async_celery_task(
+ lambda: _ai_sort_search_space_async(search_space_id, user_id)
+ )
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
@@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
)
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
"""Incremental AI sort for a single document after indexing."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- _ai_sort_document_async(search_space_id, user_id, document_id)
- )
- finally:
- loop.close()
+ return run_async_celery_task(
+ lambda: _ai_sort_document_async(search_space_id, user_id, document_id)
+ )
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):
diff --git a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py
index 98b107af3..c6c8666f5 100644
--- a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py
@@ -2,14 +2,13 @@
from __future__ import annotations
-import asyncio
import logging
from app.celery_app import celery_app
from app.db import SearchSourceConnector
from app.schemas.obsidian_plugin import NotePayload
from app.services.obsidian_plugin_indexer import upsert_note
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -22,18 +21,13 @@ def index_obsidian_attachment_task(
user_id: str,
) -> None:
"""Process one Obsidian non-markdown attachment asynchronously."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- _index_obsidian_attachment(
- connector_id=connector_id,
- payload_data=payload_data,
- user_id=user_id,
- )
+ return run_async_celery_task(
+ lambda: _index_obsidian_attachment(
+ connector_id=connector_id,
+ payload_data=payload_data,
+ user_id=user_id,
)
- finally:
- loop.close()
+ )
async def _index_obsidian_attachment(
diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
index 937877473..8b311576e 100644
--- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
@@ -3,6 +3,7 @@
import asyncio
import logging
import sys
+from contextlib import asynccontextmanager
from sqlalchemy import select
@@ -12,11 +13,12 @@ from app.celery_app import celery_app
from app.config import config as app_config
from app.db import Podcast, PodcastStatus
from app.services.billable_calls import (
+ BillingSettlementError,
QuotaInsufficientError,
_resolve_agent_billing_for_search_space,
billable_call,
)
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -34,6 +36,13 @@ if sys.platform.startswith("win"):
# =============================================================================
+@asynccontextmanager
+async def _celery_billable_session():
+ """Session factory used by billable_call inside the Celery worker loop."""
+ async with get_celery_session_maker()() as session:
+ yield session
+
+
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
self,
@@ -46,27 +55,22 @@ def generate_content_podcast_task(
Celery task to generate podcast from source content.
Updates existing podcast record created by the tool.
"""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- result = loop.run_until_complete(
- _generate_content_podcast(
+ return run_async_celery_task(
+ lambda: _generate_content_podcast(
podcast_id,
source_content,
search_space_id,
user_prompt,
)
)
- loop.run_until_complete(loop.shutdown_asyncgens())
- return result
except Exception as e:
logger.error(f"Error generating content podcast: {e!s}")
- loop.run_until_complete(_mark_podcast_failed(podcast_id))
+ try:
+ run_async_celery_task(lambda: _mark_podcast_failed(podcast_id))
+ except Exception:
+ logger.exception("Failed to mark podcast %s as failed", podcast_id)
return {"status": "failed", "podcast_id": podcast_id}
- finally:
- asyncio.set_event_loop(None)
- loop.close()
async def _mark_podcast_failed(podcast_id: int) -> None:
@@ -148,11 +152,12 @@ async def _generate_content_podcast(
base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
usage_type="podcast_generation",
- thread_id=podcast.thread_id,
call_details={
"podcast_id": podcast.id,
"title": podcast.title,
+ "thread_id": podcast.thread_id,
},
+ billable_session_factory=_celery_billable_session,
):
graph_result = await podcaster_graph.ainvoke(
initial_state, config=graph_config
@@ -173,6 +178,18 @@ async def _generate_content_podcast(
"podcast_id": podcast.id,
"reason": "premium_quota_exhausted",
}
+ except BillingSettlementError:
+ logger.exception(
+ "Podcast %s: premium billing settlement failed",
+ podcast.id,
+ )
+ podcast.status = PodcastStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "podcast_id": podcast.id,
+ "reason": "billing_settlement_failed",
+ }
podcast_transcript = graph_result.get("podcast_transcript", [])
file_path = graph_result.get("final_podcast_file_path", "")
@@ -194,7 +211,14 @@ async def _generate_content_podcast(
podcast.podcast_transcript = serializable_transcript
podcast.file_location = file_path
podcast.status = PodcastStatus.READY
+ logger.info(
+ "Podcast %s: committing READY transcript_entries=%d file=%s",
+ podcast.id,
+ len(serializable_transcript),
+ file_path,
+ )
await session.commit()
+ logger.info("Podcast %s: READY commit complete", podcast.id)
logger.info(f"Successfully generated podcast: {podcast.id}")
diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py
index 373f04b48..e41251407 100644
--- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py
+++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py
@@ -7,7 +7,7 @@ from sqlalchemy.future import select
from app.celery_app import celery_app
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from app.utils.indexing_locks import is_connector_indexing_locked
logger = logging.getLogger(__name__)
@@ -20,15 +20,7 @@ def check_periodic_schedules_task():
This task runs every minute and triggers indexing for any connector
whose next_scheduled_at time has passed.
"""
- import asyncio
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(_check_and_trigger_schedules())
- finally:
- loop.close()
+ return run_async_celery_task(_check_and_trigger_schedules)
async def _check_and_trigger_schedules():
diff --git a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py
index e05ae9435..d51c85dee 100644
--- a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py
+++ b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py
@@ -34,7 +34,7 @@ from sqlalchemy.future import select
from app.celery_app import celery_app
from app.config import config
from app.db import Document, DocumentStatus, Notification
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task():
Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task.
Also marks associated pending/processing documents as failed.
"""
- import asyncio
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
+ async def _both() -> None:
+ await _cleanup_stale_notifications()
+ await _cleanup_stale_document_processing_notifications()
- try:
- loop.run_until_complete(_cleanup_stale_notifications())
- loop.run_until_complete(_cleanup_stale_document_processing_notifications())
- finally:
- loop.close()
+ return run_async_celery_task(_both)
async def _cleanup_stale_notifications():
diff --git a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py
index 3aee1a360..ace6ef7ca 100644
--- a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py
+++ b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py
@@ -2,7 +2,6 @@
from __future__ import annotations
-import asyncio
import logging
from datetime import UTC, datetime, timedelta
@@ -18,7 +17,7 @@ from app.db import (
PremiumTokenPurchaseStatus,
)
from app.routes import stripe_routes
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None:
@celery_app.task(name="reconcile_pending_stripe_page_purchases")
def reconcile_pending_stripe_page_purchases_task():
"""Recover paid purchases that were left pending due to missed webhook handling."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(_reconcile_pending_page_purchases())
- finally:
- loop.close()
+ return run_async_celery_task(_reconcile_pending_page_purchases)
async def _reconcile_pending_page_purchases() -> None:
@@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None:
@celery_app.task(name="reconcile_pending_stripe_token_purchases")
def reconcile_pending_stripe_token_purchases_task():
"""Recover paid token purchases that were left pending due to missed webhook handling."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(_reconcile_pending_token_purchases())
- finally:
- loop.close()
+ return run_async_celery_task(_reconcile_pending_token_purchases)
async def _reconcile_pending_token_purchases() -> None:
diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
index 4f0c427d9..08f22140c 100644
--- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
@@ -3,6 +3,7 @@
import asyncio
import logging
import sys
+from contextlib import asynccontextmanager
from sqlalchemy import select
@@ -12,11 +13,12 @@ from app.celery_app import celery_app
from app.config import config as app_config
from app.db import VideoPresentation, VideoPresentationStatus
from app.services.billable_calls import (
+ BillingSettlementError,
QuotaInsufficientError,
_resolve_agent_billing_for_search_space,
billable_call,
)
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -29,6 +31,13 @@ if sys.platform.startswith("win"):
)
+@asynccontextmanager
+async def _celery_billable_session():
+ """Session factory used by billable_call inside the Celery worker loop."""
+ async with get_celery_session_maker()() as session:
+ yield session
+
+
@celery_app.task(name="generate_video_presentation", bind=True)
def generate_video_presentation_task(
self,
@@ -41,27 +50,30 @@ def generate_video_presentation_task(
Celery task to generate video presentation from source content.
Updates existing video presentation record created by the tool.
"""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
try:
- result = loop.run_until_complete(
- _generate_video_presentation(
+ return run_async_celery_task(
+ lambda: _generate_video_presentation(
video_presentation_id,
source_content,
search_space_id,
user_prompt,
)
)
- loop.run_until_complete(loop.shutdown_asyncgens())
- return result
except Exception as e:
logger.error(f"Error generating video presentation: {e!s}")
- loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id))
+ # Mark FAILED in a fresh loop — the previous loop is closed.
+ # Swallow secondary failures; the row will simply stay in
+ # GENERATING and be flushed by the periodic stale cleanup.
+ try:
+ run_async_celery_task(
+ lambda: _mark_video_presentation_failed(video_presentation_id)
+ )
+ except Exception:
+ logger.exception(
+ "Failed to mark video presentation %s as failed",
+ video_presentation_id,
+ )
return {"status": "failed", "video_presentation_id": video_presentation_id}
- finally:
- asyncio.set_event_loop(None)
- loop.close()
async def _mark_video_presentation_failed(video_presentation_id: int) -> None:
@@ -150,11 +162,12 @@ async def _generate_video_presentation(
base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
usage_type="video_presentation_generation",
- thread_id=video_pres.thread_id,
call_details={
"video_presentation_id": video_pres.id,
"title": video_pres.title,
+ "thread_id": video_pres.thread_id,
},
+ billable_session_factory=_celery_billable_session,
):
graph_result = await video_presentation_graph.ainvoke(
initial_state, config=graph_config
@@ -175,6 +188,18 @@ async def _generate_video_presentation(
"video_presentation_id": video_pres.id,
"reason": "premium_quota_exhausted",
}
+ except BillingSettlementError:
+ logger.exception(
+ "VideoPresentation %s: premium billing settlement failed",
+ video_pres.id,
+ )
+ video_pres.status = VideoPresentationStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "video_presentation_id": video_pres.id,
+ "reason": "billing_settlement_failed",
+ }
# Serialize slides (parsed content + audio info merged)
slides_raw = graph_result.get("slides", [])
@@ -205,7 +230,14 @@ async def _generate_video_presentation(
video_pres.slides = serializable_slides
video_pres.scene_codes = serializable_scene_codes
video_pres.status = VideoPresentationStatus.READY
+ logger.info(
+ "VideoPresentation %s: committing READY slides=%d scene_codes=%d",
+ video_pres.id,
+ len(serializable_slides),
+ len(serializable_scene_codes),
+ )
await session.commit()
+ logger.info("VideoPresentation %s: READY commit complete", video_pres.id)
logger.info(f"Successfully generated video presentation: {video_pres.id}")
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index 31c0d7d6d..c6ac3311a 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -1506,10 +1506,10 @@ async def _stream_agent_events(
if isinstance(tool_output, dict)
else "Podcast"
)
- if podcast_status == "processing":
+ if podcast_status in ("pending", "generating", "processing"):
completed_items = [
f"Title: {podcast_title}",
- "Audio generation started",
+ "Podcast generation started",
"Processing in background...",
]
elif podcast_status == "already_generating":
@@ -1518,7 +1518,7 @@ async def _stream_agent_events(
"Podcast already in progress",
"Please wait for it to complete",
]
- elif podcast_status == "error":
+ elif podcast_status in ("failed", "error"):
error_msg = (
tool_output.get("error", "Unknown error")
if isinstance(tool_output, dict)
@@ -1528,6 +1528,11 @@ async def _stream_agent_events(
f"Title: {podcast_title}",
f"Error: {error_msg[:50]}",
]
+ elif podcast_status in ("ready", "success"):
+ completed_items = [
+ f"Title: {podcast_title}",
+ "Podcast ready",
+ ]
else:
completed_items = last_active_step_items
yield streaming_service.format_thinking_step(
@@ -1710,20 +1715,28 @@ async def _stream_agent_events(
if isinstance(tool_output, dict)
else {"result": tool_output},
)
- if (
- isinstance(tool_output, dict)
- and tool_output.get("status") == "success"
+ if isinstance(tool_output, dict) and tool_output.get("status") in (
+ "pending",
+ "generating",
+ "processing",
+ ):
+ yield streaming_service.format_terminal_info(
+ f"Podcast queued: {tool_output.get('title', 'Podcast')}",
+ "success",
+ )
+ elif isinstance(tool_output, dict) and tool_output.get("status") in (
+ "ready",
+ "success",
):
yield streaming_service.format_terminal_info(
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
"success",
)
- else:
- error_msg = (
- tool_output.get("error", "Unknown error")
- if isinstance(tool_output, dict)
- else "Unknown error"
- )
+ elif isinstance(tool_output, dict) and tool_output.get("status") in (
+ "failed",
+ "error",
+ ):
+ error_msg = tool_output.get("error", "Unknown error")
yield streaming_service.format_terminal_info(
f"Podcast generation failed: {error_msg}",
"error",
@@ -2292,6 +2305,11 @@ async def stream_new_chat(
)
_t0 = time.perf_counter()
+ # Image-bearing turns force the Auto-pin resolver to filter the
+ # candidate pool to vision-capable cfgs (and force-repin a
+ # text-only existing pin). For explicit selections this flag is
+ # a no-op — the resolver returns the user's chosen id unchanged.
+ _requires_image_input = bool(user_image_data_urls)
try:
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
@@ -2300,13 +2318,29 @@ async def stream_new_chat(
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=llm_config_id,
+ requires_image_input=_requires_image_input,
)
).resolved_llm_config_id
except ValueError as pin_error:
+ # Auto-pin's "no vision-capable cfg" path raises a ValueError
+ # whose message we map to the friendly image-input SSE error
+ # so the user sees the same message regardless of whether
+ # the gate fired in Auto-mode or in the agent_config check
+ # below.
+ error_code = (
+ "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
+ if _requires_image_input and "vision-capable" in str(pin_error)
+ else "SERVER_ERROR"
+ )
+ error_kind = (
+ "user_error"
+ if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
+ else "server_error"
+ )
yield _emit_stream_error(
message=str(pin_error),
- error_kind="server_error",
- error_code="SERVER_ERROR",
+ error_kind=error_kind,
+ error_code=error_code,
)
yield streaming_service.format_done()
return
@@ -2326,6 +2360,50 @@ async def stream_new_chat(
llm_config_id,
)
+ # Capability safety net: a turn carrying user-uploaded images
+ # cannot be routed to a chat config that LiteLLM's authoritative
+ # model map *explicitly* marks as text-only (``supports_vision``
+ # set to False). The check is intentionally narrow — it only
+ # fires when LiteLLM is *certain* the model can't accept image
+ # input. Unknown / unmapped / vision-capable models pass
+ # through. Without this guard a known-text-only model would 404
+ # at the provider with ``"No endpoints found that support image
+ # input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk;
+ # failing here lets us return a friendly message that tells the
+ # user what to change.
+ if user_image_data_urls and agent_config is not None:
+ from app.services.provider_capabilities import (
+ is_known_text_only_chat_model,
+ )
+
+ agent_litellm_params = agent_config.litellm_params or {}
+ agent_base_model = (
+ agent_litellm_params.get("base_model")
+ if isinstance(agent_litellm_params, dict)
+ else None
+ )
+ if is_known_text_only_chat_model(
+ provider=agent_config.provider,
+ model_name=agent_config.model_name,
+ base_model=agent_base_model,
+ custom_provider=agent_config.custom_provider,
+ ):
+ model_label = (
+ agent_config.config_name or agent_config.model_name or "model"
+ )
+ yield _emit_stream_error(
+ message=(
+ f"The selected model ({model_label}) does not support "
+ "image input. Switch to a vision-capable model "
+ "(e.g. GPT-4o, Claude, Gemini) or remove the image "
+ "attachment and try again."
+ ),
+ error_kind="user_error",
+ error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
+ )
+ yield streaming_service.format_done()
+ return
+
# Premium quota reservation for pinned premium model only.
_needs_premium_quota = (
agent_config is not None and user_id and agent_config.is_premium
@@ -2366,6 +2444,7 @@ async def stream_new_chat(
user_id=user_id,
selected_llm_config_id=0,
force_repin_free=True,
+ requires_image_input=_requires_image_input,
)
).resolved_llm_config_id
except ValueError as pin_error:
@@ -2470,6 +2549,7 @@ async def stream_new_chat(
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
+ requires_image_input=_requires_image_input,
)
).resolved_llm_config_id
except ValueError as pin_error:
@@ -2804,6 +2884,7 @@ async def stream_new_chat(
from litellm import acompletion
from app.services.llm_router_service import LLMRouterService
+ from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import _turn_accumulator
_turn_accumulator.set(None)
@@ -2824,11 +2905,32 @@ async def stream_new_chat(
model="auto", messages=messages
)
else:
+ # Apply the same ``api_base`` cascade chat / vision /
+ # image-gen call sites use so we never inherit
+ # ``litellm.api_base`` (commonly set by
+ # ``AZURE_OPENAI_ENDPOINT``) when the chat config
+ # itself ships an empty ``api_base``. Without this
+ # the title-gen on an OpenRouter chat config would
+ # 404 against the inherited Azure endpoint — see
+ # ``provider_api_base`` docstring for the same
+ # bug repro on the image-gen / vision paths.
+ raw_model = getattr(llm, "model", "") or ""
+ provider_prefix = (
+ raw_model.split("/", 1)[0] if "/" in raw_model else None
+ )
+ provider_value = (
+ agent_config.provider if agent_config is not None else None
+ )
+ title_api_base = resolve_api_base(
+ provider=provider_value,
+ provider_prefix=provider_prefix,
+ config_api_base=getattr(llm, "api_base", None),
+ )
response = await acompletion(
- model=llm.model,
+ model=raw_model,
messages=messages,
api_key=getattr(llm, "api_key", None),
- api_base=getattr(llm, "api_base", None),
+ api_base=title_api_base,
)
usage_info = None
@@ -2953,6 +3055,7 @@ async def stream_new_chat(
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
+ requires_image_input=_requires_image_input,
)
).resolved_llm_config_id
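The image-input safety net above delegates to is_known_text_only_chat_model from app.services.provider_capabilities, which is not part of this diff. The sketch below only illustrates the narrow semantics the comment describes (block only when LiteLLM's model map explicitly reports supports_vision as False, default-allow anything it does not recognise); the body, including how provider and custom_provider would feed into the lookup, is an assumption:

    import litellm

    def is_known_text_only_chat_model(
        *,
        provider: str | None,
        model_name: str | None,
        base_model: str | None = None,
        custom_provider: str | None = None,
    ) -> bool:
        """True only when LiteLLM explicitly marks the model text-only."""
        # Azure-style configs carry the canonical sku in base_model, so it
        # takes precedence over the opaque deployment id in model_name.
        candidate = base_model or model_name
        if not candidate:
            return False
        try:
            info = litellm.get_model_info(candidate)
        except Exception:
            return False  # unmapped model: stay permissive
        # provider / custom_provider would normally help build a prefixed
        # model string for the lookup; that normalisation is omitted here.
        return info.get("supports_vision") is False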
diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py
new file mode 100644
index 000000000..a49d4eab2
--- /dev/null
+++ b/surfsense_backend/scripts/verify_chat_image_capability.py
@@ -0,0 +1,558 @@
+"""End-to-end smoke test for vision / image config wiring.
+
+Loads the live ``global_llm_config.yaml`` (no mocking, no fixtures) and
+exercises every chat / vision / image-generation config + the OpenRouter
+dynamic catalog. For each config the script:
+
+1. Reports the resolver classification (catalog-allow vs strict-block).
+2. Optionally fires a tiny live API call against the provider:
+ - Chat configs: ``litellm.acompletion`` with a 1x1 PNG and the prompt
+ ``"reply with one word: ok"``.
+ - Vision configs: same, against the dedicated vision router pool.
+ - Image-gen configs: ``litellm.aimage_generation`` with a single tiny
+ prompt and ``n=1``.
+ - OpenRouter integration: samples one chat, one vision, one image-gen
+ model from the dynamically fetched catalog.
+
+Usage::
+
+ python -m scripts.verify_chat_image_capability # capability + connectivity
+ python -m scripts.verify_chat_image_capability --no-live # capability resolver only
+
+The script is meant to be runnable from the repository root or from
+``surfsense_backend/`` and prints a short PASS/FAIL/SKIP summary at the
+end so it's usable as a CI smoke check too.
+
+Live-mode caveat: each successful call costs a small amount of provider
+credit (a few tokens or one tiny generated image per config). The
+default size for image generation is ``1024x1024`` because Azure
+GPT-image deployments reject smaller sizes; OpenRouter image-gen models
+generally accept the same size.
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import logging
+import os
+import sys
+import time
+from dataclasses import dataclass, field
+from typing import Any
+
+# Bootstrap the surfsense_backend package on sys.path so the script runs
+# from the repo root or from `surfsense_backend/` interchangeably.
+_HERE = os.path.dirname(os.path.abspath(__file__))
+_BACKEND_ROOT = os.path.dirname(_HERE)
+if _BACKEND_ROOT not in sys.path:
+ sys.path.insert(0, _BACKEND_ROOT)
+
+import litellm # noqa: E402
+
+from app.config import config # noqa: E402
+from app.services.openrouter_integration_service import ( # noqa: E402
+ _OPENROUTER_DYNAMIC_MARKER,
+ OpenRouterIntegrationService,
+)
+from app.services.provider_api_base import resolve_api_base # noqa: E402
+from app.services.provider_capabilities import ( # noqa: E402
+ derive_supports_image_input,
+ is_known_text_only_chat_model,
+)
+
+logging.basicConfig(
+ level=logging.WARNING,
+ format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
+)
+# Quiet down LiteLLM's verbose router/cost logs so the script output is
+# scannable.
+logging.getLogger("LiteLLM").setLevel(logging.ERROR)
+logging.getLogger("litellm").setLevel(logging.ERROR)
+logging.getLogger("httpx").setLevel(logging.ERROR)
+
+# 1x1 transparent PNG — used as the cheapest possible vision payload.
+_TINY_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
+_TINY_PNG_DATA_URL = f"data:image/png;base64,{_TINY_PNG_B64}"
+
+
+# ---------------------------------------------------------------------------
+# Result accounting
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class ProbeResult:
+ label: str
+ surface: str
+ config_id: int | str
+ capability_ok: bool | None = None
+ capability_note: str = ""
+ live_ok: bool | None = None
+ live_note: str = ""
+ duration_s: float = 0.0
+
+
+@dataclass
+class Report:
+ results: list[ProbeResult] = field(default_factory=list)
+
+ def add(self, r: ProbeResult) -> None:
+ self.results.append(r)
+
+ def render(self) -> int:
+ passed = failed = skipped = 0
+ print()
+ print("=" * 92)
+ print(
+ f"{'Surface':<14}{'ID':>8} {'Cap':>5} {'Live':>5} {'Time':>6} Label / notes"
+ )
+ print("-" * 92)
+ for r in self.results:
+
+ def _flag(value: bool | None) -> str:
+ if value is None:
+ return "skip"
+ return "ok" if value else "fail"
+
+ cap = _flag(r.capability_ok)
+ live = _flag(r.live_ok)
+ if r.capability_ok is False or r.live_ok is False:
+ failed += 1
+ elif r.capability_ok is None and r.live_ok is None:
+ skipped += 1
+ else:
+ passed += 1
+ print(
+ f"{r.surface:<14}{r.config_id!s:>8} {cap:>5} {live:>5} "
+ f"{r.duration_s:>5.2f}s {r.label}"
+ )
+ if r.capability_note:
+ print(f" cap: {r.capability_note}")
+ if r.live_note:
+ print(f" live: {r.live_note}")
+ print("-" * 92)
+ print(
+ f"Total: {passed} ok / {failed} fail / {skipped} skip "
+ f"(of {len(self.results)} probes)"
+ )
+ print("=" * 92)
+ return failed
+
+
+# ---------------------------------------------------------------------------
+# Capability probes (no network)
+# ---------------------------------------------------------------------------
+
+
+def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
+ """For chat configs the catalog flag is *expected* True (vision-capable
+ pool). The probe reports both the resolver value and the strict
+ safety-net value to surface any drift between them."""
+ litellm_params = cfg.get("litellm_params") or {}
+ base_model = (
+ litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
+ )
+ cap = derive_supports_image_input(
+ provider=cfg.get("provider"),
+ model_name=cfg.get("model_name"),
+ base_model=base_model,
+ custom_provider=cfg.get("custom_provider"),
+ )
+ block = is_known_text_only_chat_model(
+ provider=cfg.get("provider"),
+ model_name=cfg.get("model_name"),
+ base_model=base_model,
+ custom_provider=cfg.get("custom_provider"),
+ )
+ note = f"derive={cap} strict_block={block}"
+ if not cap and not block:
+ # Resolver said False but strict gate is also False — that means
+ # OR modalities published [text] explicitly. Surface it.
+ note += " (OR modality says text-only)"
+ # We accept a True derive *or* (False derive AND False block) as
+ # 'capability ok' — either way, the streaming task will flow through.
+ ok = cap or not block
+ return ok, note
+
+
+def _build_chat_model_string(cfg: dict) -> str:
+ if cfg.get("custom_provider"):
+ return f"{cfg['custom_provider']}/{cfg['model_name']}"
+ from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
+
+ prefix = _PROVIDER_PREFIX_MAP.get(
+ (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
+ )
+ return f"{prefix}/{cfg['model_name']}"
+
+
+# ---------------------------------------------------------------------------
+# Live probes (network calls)
+# ---------------------------------------------------------------------------
+
+
+async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
+ """Send a 1x1 PNG + `reply with one word: ok` to the chat config."""
+ model_string = _build_chat_model_string(cfg)
+ api_base = resolve_api_base(
+ provider=cfg.get("provider"),
+ provider_prefix=model_string.split("/", 1)[0],
+ config_api_base=cfg.get("api_base") or None,
+ )
+ kwargs: dict[str, Any] = {
+ "model": model_string,
+ "api_key": cfg.get("api_key"),
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "reply with one word: ok"},
+ {
+ "type": "image_url",
+ "image_url": {"url": _TINY_PNG_DATA_URL},
+ },
+ ],
+ }
+ ],
+ "max_tokens": 16,
+ "timeout": 60,
+ }
+ if api_base:
+ kwargs["api_base"] = api_base
+ if cfg.get("litellm_params"):
+ # Strip pricing keys — they're tracking-only and confuse some
+ # provider validators (e.g. azure/openai reject unknown kwargs
+ # in strict mode).
+ merged = {
+ k: v
+ for k, v in dict(cfg["litellm_params"]).items()
+ if k
+ not in {
+ "input_cost_per_token",
+ "output_cost_per_token",
+ "input_cost_per_pixel",
+ "output_cost_per_pixel",
+ }
+ }
+ kwargs.update(merged)
+ try:
+ resp = await litellm.acompletion(**kwargs)
+ except Exception as exc:
+ return False, f"{type(exc).__name__}: {exc}"
+ text = resp.choices[0].message.content if resp.choices else ""
+ return True, f"got reply ({(text or '').strip()[:40]!r})"
+
+
+# Gemini image models occasionally return zero-length ``data`` for the
+# minimal "red dot on white" prompt (provider-side safety / empty-output
+# quirk reproducible against ``google/gemini-2.5-flash-image`` even when
+# the request itself succeeds). Use a more naturalistic prompt and
+# retry once with a different one before giving up.
+_IMAGE_GEN_PROMPTS: tuple[str, ...] = (
+ "A simple icon of a coffee cup, flat illustration",
+ "A small green leaf on a white background",
+)
+
+
+async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
+ """Generate one tiny image to verify the deployment is reachable."""
+ from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
+
+ if cfg.get("custom_provider"):
+ prefix = cfg["custom_provider"]
+ else:
+ prefix = _PROVIDER_PREFIX_MAP.get(
+ (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
+ )
+ model_string = f"{prefix}/{cfg['model_name']}"
+ api_base = resolve_api_base(
+ provider=cfg.get("provider"),
+ provider_prefix=prefix,
+ config_api_base=cfg.get("api_base") or None,
+ )
+ base_kwargs: dict[str, Any] = {
+ "model": model_string,
+ "api_key": cfg.get("api_key"),
+ "n": 1,
+ "size": "1024x1024",
+ "timeout": 120,
+ }
+ if api_base:
+ base_kwargs["api_base"] = api_base
+ if cfg.get("api_version"):
+ base_kwargs["api_version"] = cfg["api_version"]
+ if cfg.get("litellm_params"):
+ base_kwargs.update(
+ {
+ k: v
+ for k, v in dict(cfg["litellm_params"]).items()
+ if k
+ not in {
+ "input_cost_per_token",
+ "output_cost_per_token",
+ "input_cost_per_pixel",
+ "output_cost_per_pixel",
+ }
+ }
+ )
+
+ last_note = ""
+ for attempt, prompt in enumerate(_IMAGE_GEN_PROMPTS, start=1):
+ try:
+ resp = await litellm.aimage_generation(prompt=prompt, **base_kwargs)
+ except Exception as exc:
+ last_note = f"{type(exc).__name__}: {exc}"
+ continue
+ data_count = len(getattr(resp, "data", None) or [])
+ if data_count > 0:
+ return True, (
+ f"received {data_count} image(s) on attempt {attempt} "
+ f"(prompt={prompt!r})"
+ )
+ last_note = (
+ f"call ok but received 0 images on attempt {attempt} (prompt={prompt!r})"
+ )
+ return False, last_note
+
+
+# ---------------------------------------------------------------------------
+# Probe drivers
+# ---------------------------------------------------------------------------
+
+
+def _is_or_dynamic(cfg: dict) -> bool:
+ return bool(cfg.get(_OPENROUTER_DYNAMIC_MARKER))
+
+
+async def probe_chat_configs(report: Report, *, live: bool) -> None:
+ print("\n[chat configs from global_llm_configs (YAML-static)]")
+ for cfg in config.GLOBAL_LLM_CONFIGS:
+ # Skip OR dynamic entries here — handled in the OR section so
+ # the YAML / OR split stays clear in the report.
+ if _is_or_dynamic(cfg):
+ continue
+ result = ProbeResult(
+ label=str(cfg.get("name") or cfg.get("model_name")),
+ surface="chat-yaml",
+ config_id=cfg.get("id"),
+ )
+ cap_ok, cap_note = _probe_chat_capability(cfg)
+ result.capability_ok = cap_ok
+ result.capability_note = cap_note
+ if live:
+ t0 = time.perf_counter()
+ ok, note = await _live_chat_image_call(cfg)
+ result.live_ok = ok
+ result.live_note = note
+ result.duration_s = time.perf_counter() - t0
+ report.add(result)
+
+
+async def probe_vision_configs(report: Report, *, live: bool) -> None:
+ print("\n[vision configs from global_vision_llm_configs (YAML-static)]")
+ for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
+ if _is_or_dynamic(cfg):
+ continue
+ result = ProbeResult(
+ label=str(cfg.get("name") or cfg.get("model_name")),
+ surface="vision",
+ config_id=cfg.get("id"),
+ )
+ # For vision configs, capability is implied — they're in the
+ # dedicated vision pool. Run the same resolver to flag any
+ # surprise disagreement.
+ cap_ok, cap_note = _probe_chat_capability(cfg)
+ result.capability_ok = cap_ok
+ result.capability_note = cap_note
+ if live:
+ t0 = time.perf_counter()
+ ok, note = await _live_chat_image_call(cfg)
+ result.live_ok = ok
+ result.live_note = note
+ result.duration_s = time.perf_counter() - t0
+ report.add(result)
+
+
+async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
+ print(
+ "\n[image generation configs from global_image_generation_configs (YAML-static)]"
+ )
+ for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
+ if _is_or_dynamic(cfg):
+ continue
+ result = ProbeResult(
+ label=str(cfg.get("name") or cfg.get("model_name")),
+ surface="image-gen",
+ config_id=cfg.get("id"),
+ )
+ # Image gen configs don't have a "supports_image_input" flag;
+ # the catalog tracks output, not input. Mark capability as None
+ # (skip) for the report.
+ if live:
+ t0 = time.perf_counter()
+ ok, note = await _live_image_gen_call(cfg)
+ result.live_ok = ok
+ result.live_note = note
+ result.duration_s = time.perf_counter() - t0
+ report.add(result)
+
+
+async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
+ """Sample one chat (vision-capable), one vision, one image-gen model
+    from the live OpenRouter catalog. Doesn't iterate the full pool
+ (would be hundreds of probes); just validates the integration end-
+ to-end on a representative model from each surface."""
+ print("\n[OpenRouter integration: sampled probes]")
+ settings = config.OPENROUTER_INTEGRATION_SETTINGS
+ if not settings:
+ report.add(
+ ProbeResult(
+ label="OpenRouter integration",
+ surface="openrouter",
+ config_id="settings",
+ capability_ok=None,
+ capability_note="openrouter_integration disabled in YAML — skipping",
+ live_ok=None,
+ )
+ )
+ return
+
+ service = OpenRouterIntegrationService.get_instance()
+ or_chat = [
+ c
+ for c in config.GLOBAL_LLM_CONFIGS
+ if c.get("provider") == "OPENROUTER" and c.get("supports_image_input")
+ ]
+ or_vision = [
+ c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER"
+ ]
+ or_image_gen = [
+ c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER"
+ ]
+
+ # Pick one representative per provider family per surface so a single
+ # broken vendor (e.g. Anthropic key revoked, Google quota exceeded)
+ # surfaces independently of the others. Each needle matches the
+ # OpenRouter ``model_name`` prefix; the first match wins.
+ def _pick_first(pool: list[dict], needle: str) -> dict | None:
+ for c in pool:
+ if (c.get("model_name") or "").lower().startswith(needle):
+ return c
+ return None
+
+ chat_picks = [
+ ("or-chat", _pick_first(or_chat, "openai/gpt-4o")),
+ ("or-chat", _pick_first(or_chat, "anthropic/claude")),
+ ("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")),
+ ]
+ vision_picks = [
+ ("or-vision", _pick_first(or_vision, "openai/gpt-4o")),
+ ("or-vision", _pick_first(or_vision, "anthropic/claude")),
+ ("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")),
+ ]
+ image_picks = [
+ ("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")),
+ # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*``
+ # / ``openai/gpt-5.4-image-2`` (no ``gpt-image`` literal). Match
+ # the actual prefix.
+ ("or-image", _pick_first(or_image_gen, "openai/gpt-5-image")),
+ ]
+
+ print(
+ f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} "
+ f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})"
+ )
+
+ for surface, picked in chat_picks + vision_picks + image_picks:
+ if not picked:
+ report.add(
+ ProbeResult(
+                    label="(no candidate)",
+ surface=surface,
+ config_id="-",
+ capability_ok=None,
+ capability_note="no candidate found in OR catalog",
+ )
+ )
+ continue
+ runner = (
+ _live_image_gen_call if surface == "or-image" else _live_chat_image_call
+ )
+ result = ProbeResult(
+ label=str(picked.get("model_name")),
+ surface=surface,
+ config_id=picked.get("id"),
+ )
+ if surface != "or-image":
+ cap_ok, cap_note = _probe_chat_capability(picked)
+ result.capability_ok = cap_ok
+ result.capability_note = cap_note
+ if live:
+ t0 = time.perf_counter()
+ ok, note = await runner(picked)
+ result.live_ok = ok
+ result.live_note = note
+ result.duration_s = time.perf_counter() - t0
+ report.add(result)
+
+
+# ---------------------------------------------------------------------------
+# Entry point
+# ---------------------------------------------------------------------------
+
+
+async def main(args: argparse.Namespace) -> int:
+ print("Loaded global configs:")
+ print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries")
+ print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries")
+ print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries")
+ print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}")
+
+ # Initialize the OpenRouter integration so the catalog is populated
+ # (this is what main.py does at startup). It's idempotent.
+ if config.OPENROUTER_INTEGRATION_SETTINGS:
+ try:
+ from app.config import initialize_openrouter_integration
+
+ initialize_openrouter_integration()
+ except Exception as exc:
+ print(f" WARNING: OpenRouter integration init failed: {exc}")
+
+ print(
+ f"\nMode: {'LIVE (will hit providers)' if args.live else 'DRY (capability only)'}"
+ )
+
+ report = Report()
+ if not args.skip_chat:
+ await probe_chat_configs(report, live=args.live)
+ if not args.skip_vision:
+ await probe_vision_configs(report, live=args.live)
+ if not args.skip_image_gen:
+ await probe_image_gen_configs(report, live=args.live)
+ if not args.skip_openrouter:
+ await probe_openrouter_catalog(report, live=args.live)
+
+ failed = report.render()
+ return 1 if failed else 0
+
+
+def _parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--no-live",
+ dest="live",
+ action="store_false",
+ help="Skip live API calls — capability resolver only.",
+ )
+ parser.set_defaults(live=True)
+ parser.add_argument("--skip-chat", action="store_true")
+ parser.add_argument("--skip-vision", action="store_true")
+ parser.add_argument("--skip-image-gen", action="store_true")
+ parser.add_argument("--skip-openrouter", action="store_true")
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = _parse_args()
+ sys.exit(asyncio.run(main(args)))
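Both the title-generation path patched in stream_new_chat.py and this script route their endpoint selection through resolve_api_base so the process-wide litellm.api_base is never inherited. The real cascade lives in app/services/provider_api_base.py; the sketch below only illustrates the precedence the call sites assume (config value first, then a per-provider default), and the defaults table is hypothetical apart from the OpenRouter URL that already appears in the test fixtures:

    import os

    # Hypothetical defaults; the real table may cover more providers.
    _DEFAULT_API_BASES = {
        "OPENROUTER": "https://openrouter.ai/api/v1",
    }

    def resolve_api_base(
        *,
        provider: str | None,
        provider_prefix: str | None,
        config_api_base: str | None,
    ) -> str | None:
        """Prefer the config's own api_base, then a per-provider default."""
        if config_api_base:
            return config_api_base
        key = (provider or provider_prefix or "").upper()
        if key in {"AZURE_OPENAI", "AZURE"}:
            # Azure deployments keep using their own endpoint variable
            # (assumption; matches the AZURE_OPENAI_ENDPOINT note above).
            return os.environ.get("AZURE_OPENAI_ENDPOINT")
        return _DEFAULT_API_BASES.get(key)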
diff --git a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py
new file mode 100644
index 000000000..c9f18d77d
--- /dev/null
+++ b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py
@@ -0,0 +1,110 @@
+"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
+endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).
+
+There is no DB column for ``supports_image_input`` on
+``NewLLMConfig`` — the value is resolved at the API boundary by
+``derive_supports_image_input`` so the new-chat selector / streaming
+task can read the same field shape regardless of source (BYOK vs YAML
+vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
+user out of their own model choice.
+"""
+
+from __future__ import annotations
+
+from datetime import UTC, datetime
+from types import SimpleNamespace
+from uuid import uuid4
+
+import pytest
+
+from app.db import LiteLLMProvider
+from app.routes import new_llm_config_routes
+
+pytestmark = pytest.mark.unit
+
+
+def _byok_row(
+ *,
+ id_: int,
+ model_name: str,
+ base_model: str | None = None,
+ provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
+ custom_provider: str | None = None,
+) -> object:
+ """Mimic the SQLAlchemy row's attribute surface; ``model_validate``
+ walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough.
+
+ ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's
+ enum validator accepts it — same as the ORM row would carry."""
+ return SimpleNamespace(
+ id=id_,
+ name=f"BYOK-{id_}",
+ description=None,
+ provider=provider,
+ custom_provider=custom_provider,
+ model_name=model_name,
+ api_key="sk-byok",
+ api_base=None,
+ litellm_params={"base_model": base_model} if base_model else None,
+ system_instructions="",
+ use_default_system_instructions=True,
+ citations_enabled=True,
+ created_at=datetime.now(tz=UTC),
+ search_space_id=42,
+ user_id=uuid4(),
+ )
+
+
+def test_serialize_byok_known_vision_model_resolves_true():
+ """The catalog resolver consults LiteLLM's map for ``gpt-4o`` ->
+ True. The serialized row carries that value through to the
+ ``NewLLMConfigRead`` schema."""
+ row = _byok_row(id_=1, model_name="gpt-4o")
+ serialized = new_llm_config_routes._serialize_byok_config(row)
+
+ assert serialized.supports_image_input is True
+ assert serialized.id == 1
+ assert serialized.model_name == "gpt-4o"
+
+
+def test_serialize_byok_unknown_model_default_allows():
+ """Unknown / unmapped: default-allow. The streaming-task safety net
+ is the actual block, and it requires LiteLLM to *explicitly* say
+ text-only — so a brand new BYOK model should not be pre-judged."""
+ row = _byok_row(
+ id_=2,
+ model_name="brand-new-model-x9-unmapped",
+ provider=LiteLLMProvider.CUSTOM,
+ custom_provider="brand_new_proxy",
+ )
+ serialized = new_llm_config_routes._serialize_byok_config(row)
+
+ assert serialized.supports_image_input is True
+
+
+def test_serialize_byok_uses_base_model_when_present():
+ """Azure-style: ``model_name`` is the deployment id, ``base_model``
+ inside ``litellm_params`` is the canonical sku LiteLLM knows. The
+ helper must consult ``base_model`` first or unrecognised deployment
+ ids would shadow the real capability."""
+ row = _byok_row(
+ id_=3,
+ model_name="my-azure-deployment-id-no-litellm-knows-this",
+ base_model="gpt-4o",
+ provider=LiteLLMProvider.AZURE_OPENAI,
+ )
+ serialized = new_llm_config_routes._serialize_byok_config(row)
+
+ assert serialized.supports_image_input is True
+
+
+def test_serialize_byok_returns_pydantic_read_model():
+ """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so
+ the schema additions are guaranteed to be present in the API
+ surface. This guards against a future regression where someone
+ deletes the augmentation step and falls back to ORM passthrough."""
+ from app.schemas import NewLLMConfigRead
+
+ row = _byok_row(id_=4, model_name="gpt-4o")
+ serialized = new_llm_config_routes._serialize_byok_config(row)
+ assert isinstance(serialized, NewLLMConfigRead)
diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py
new file mode 100644
index 000000000..2b6c76485
--- /dev/null
+++ b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py
@@ -0,0 +1,184 @@
+"""Unit tests for ``is_premium`` derivation on the global image-gen and
+vision-LLM list endpoints.
+
+Chat globals (``GET /global-llm-configs``) already emit
+``is_premium = (billing_tier == "premium")``. Image and vision did not,
+which made the new-chat ``model-selector`` render the Free/Premium badge
+on the Chat tab but skip it on the Image and Vision tabs (the selector
+keys its badge logic off ``is_premium``). These tests pin parity:
+
+* YAML free entry → ``is_premium=False``
+* YAML premium entry → ``is_premium=True``
+* OpenRouter dynamic premium entry → ``is_premium=True``
+* Auto stub (always emitted when at least one config is present)
+ → ``is_premium=False``
+"""
+
+from __future__ import annotations
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+_IMAGE_FIXTURE: list[dict] = [
+ {
+ "id": -1,
+ "name": "DALL-E 3",
+ "provider": "OPENAI",
+ "model_name": "dall-e-3",
+ "api_key": "sk-test",
+ "billing_tier": "free",
+ },
+ {
+ "id": -2,
+ "name": "GPT-Image 1 (premium)",
+ "provider": "OPENAI",
+ "model_name": "gpt-image-1",
+ "api_key": "sk-test",
+ "billing_tier": "premium",
+ },
+ {
+ "id": -20_001,
+ "name": "google/gemini-2.5-flash-image (OpenRouter)",
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-2.5-flash-image",
+ "api_key": "sk-or-test",
+ "api_base": "https://openrouter.ai/api/v1",
+ "billing_tier": "premium",
+ },
+]
+
+
+_VISION_FIXTURE: list[dict] = [
+ {
+ "id": -1,
+ "name": "GPT-4o Vision",
+ "provider": "OPENAI",
+ "model_name": "gpt-4o",
+ "api_key": "sk-test",
+ "billing_tier": "free",
+ },
+ {
+ "id": -2,
+ "name": "Claude 3.5 Sonnet (premium)",
+ "provider": "ANTHROPIC",
+ "model_name": "claude-3-5-sonnet",
+ "api_key": "sk-ant-test",
+ "billing_tier": "premium",
+ },
+ {
+ "id": -30_001,
+ "name": "openai/gpt-4o (OpenRouter)",
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-4o",
+ "api_key": "sk-or-test",
+ "api_base": "https://openrouter.ai/api/v1",
+ "billing_tier": "premium",
+ },
+]
+
+
+# =============================================================================
+# Image generation
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
+ """Each emitted config must carry ``is_premium`` derived server-side
+ from ``billing_tier``. The Auto stub is always free.
+ """
+ from app.config import config
+ from app.routes import image_generation_routes
+
+ monkeypatch.setattr(
+ config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
+ )
+
+ payload = await image_generation_routes.get_global_image_gen_configs(user=None)
+
+ by_id = {c["id"]: c for c in payload}
+
+ # Auto stub is always emitted when at least one global config exists,
+ # and it must always declare itself free (Auto-mode billing-tier
+ # surfacing is a separate follow-up).
+ assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
+ assert by_id[0]["is_premium"] is False
+ assert by_id[0]["billing_tier"] == "free"
+
+ # YAML free entry — ``is_premium=False``
+ assert by_id[-1]["is_premium"] is False
+ assert by_id[-1]["billing_tier"] == "free"
+
+ # YAML premium entry — ``is_premium=True``
+ assert by_id[-2]["is_premium"] is True
+ assert by_id[-2]["billing_tier"] == "premium"
+
+ # OpenRouter dynamic premium entry — same field, same derivation
+ assert by_id[-20_001]["is_premium"] is True
+ assert by_id[-20_001]["billing_tier"] == "premium"
+
+ # Every emitted dict (including Auto) must have the field — never missing.
+ for cfg in payload:
+ assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
+ assert isinstance(cfg["is_premium"], bool)
+
+
+@pytest.mark.asyncio
+async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
+ """When there are no global configs at all, the endpoint emits an
+ empty list (no Auto stub) — Auto mode would have nothing to route to.
+ """
+ from app.config import config
+ from app.routes import image_generation_routes
+
+ monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)
+ payload = await image_generation_routes.get_global_image_gen_configs(user=None)
+ assert payload == []
+
+
+# =============================================================================
+# Vision LLM
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
+ from app.config import config
+ from app.routes import vision_llm_routes
+
+ monkeypatch.setattr(
+ config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
+ )
+
+ payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
+
+ by_id = {c["id"]: c for c in payload}
+
+ assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
+ assert by_id[0]["is_premium"] is False
+ assert by_id[0]["billing_tier"] == "free"
+
+ assert by_id[-1]["is_premium"] is False
+ assert by_id[-1]["billing_tier"] == "free"
+
+ assert by_id[-2]["is_premium"] is True
+ assert by_id[-2]["billing_tier"] == "premium"
+
+ assert by_id[-30_001]["is_premium"] is True
+ assert by_id[-30_001]["billing_tier"] == "premium"
+
+ for cfg in payload:
+ assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
+ assert isinstance(cfg["is_premium"], bool)
+
+
+@pytest.mark.asyncio
+async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
+ from app.config import config
+ from app.routes import vision_llm_routes
+
+ monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)
+ payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
+ assert payload == []
diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py
new file mode 100644
index 000000000..b47d9134b
--- /dev/null
+++ b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py
@@ -0,0 +1,106 @@
+"""Unit tests for ``supports_image_input`` derivation on the chat global
+config endpoint (``GET /global-new-llm-configs``).
+
+Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):
+
+1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
+ loader for operator overrides, or by the OpenRouter integration from
+ ``architecture.input_modalities``) — wins.
+2. ``derive_supports_image_input`` helper — default-allow on unknown
+ models, only False when LiteLLM / OR modalities are definitive.
+
+The flag is purely informational at the API boundary. The streaming
+task safety net (``is_known_text_only_chat_model``) is the actual block,
+and it requires LiteLLM to *explicitly* mark the model as text-only.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+pytestmark = pytest.mark.unit
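+
+
+def _resolution_order_sketch(cfg: dict) -> bool:
+    """Illustrative sketch only (assumed names, not the route code): the
+    resolution order these tests pin. An explicit ``supports_image_input``
+    on the cfg dict wins; otherwise fall back to the shared derive helper,
+    which default-allows unknown models."""
+    from app.services.provider_capabilities import derive_supports_image_input
+
+    explicit = cfg.get("supports_image_input")
+    if isinstance(explicit, bool):
+        return explicit
+    return derive_supports_image_input(
+        provider=cfg.get("provider"),
+        model_name=cfg.get("model_name"),
+    )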
+
+
+_FIXTURE: list[dict] = [
+ {
+ "id": -1,
+ "name": "GPT-4o (explicit true)",
+ "description": "vision-capable, explicit YAML override",
+ "provider": "OPENAI",
+ "model_name": "gpt-4o",
+ "api_key": "sk-test",
+ "billing_tier": "free",
+ "supports_image_input": True,
+ },
+ {
+ "id": -2,
+ "name": "DeepSeek V3 (explicit false)",
+ "description": "OpenRouter dynamic — modality-derived false",
+ "provider": "OPENROUTER",
+ "model_name": "deepseek/deepseek-v3.2-exp",
+ "api_key": "sk-or-test",
+ "api_base": "https://openrouter.ai/api/v1",
+ "billing_tier": "free",
+ "supports_image_input": False,
+ },
+ {
+ "id": -10_010,
+ "name": "Unannotated GPT-4o",
+ "description": "no flag set — resolver should derive True via LiteLLM",
+ "provider": "OPENAI",
+ "model_name": "gpt-4o",
+ "api_key": "sk-test",
+ "billing_tier": "free",
+ # supports_image_input intentionally absent
+ },
+ {
+ "id": -10_011,
+ "name": "Unannotated unknown model",
+ "description": "unmapped — default-allow True",
+ "provider": "CUSTOM",
+ "custom_provider": "brand_new_proxy",
+ "model_name": "brand-new-model-x9",
+ "api_key": "sk-test",
+ "billing_tier": "free",
+ },
+]
+
+
+@pytest.mark.asyncio
+async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
+ """Each emitted chat config carries ``supports_image_input`` as a
+ bool. Explicit values win; unannotated entries are resolved via the
+ helper (default-allow True)."""
+ from app.config import config
+ from app.routes import new_llm_config_routes
+
+ monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)
+
+ payload = await new_llm_config_routes.get_global_new_llm_configs(user=None)
+ by_id = {c["id"]: c for c in payload}
+
+    # Auto stub: optimistic True so the user can keep Auto selected as long
+    # as a vision-capable deployment exists somewhere in the pool.
+ assert 0 in by_id, "Auto stub should be emitted when configs exist"
+ assert by_id[0]["supports_image_input"] is True
+ assert by_id[0]["is_auto_mode"] is True
+
+ # Explicit True is preserved.
+ assert by_id[-1]["supports_image_input"] is True
+
+ # Explicit False is preserved (the exact failure mode the safety net
+ # guards against — DeepSeek V3 over OpenRouter would 404 with "No
+ # endpoints found that support image input").
+ assert by_id[-2]["supports_image_input"] is False
+
+ # Unannotated GPT-4o: resolver consults LiteLLM, which says vision.
+ assert by_id[-10_010]["supports_image_input"] is True
+
+ # Unknown / unmapped model: default-allow rather than pre-judge.
+ assert by_id[-10_011]["supports_image_input"] is True
+
+ for cfg in payload:
+ assert "supports_image_input" in cfg, (
+ f"supports_image_input missing from {cfg.get('id')}"
+ )
+ assert isinstance(cfg["supports_image_input"], bool)
diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py
new file mode 100644
index 000000000..0e19b80e4
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py
@@ -0,0 +1,286 @@
+"""Image-aware extension of the Auto-pin resolver.
+
+When the current chat turn carries an ``image_url`` block, the pin
+resolver must:
+
+1. Filter the candidate pool to vision-capable cfgs so a freshly
+ selected pin can never be text-only.
+2. Treat any existing pin whose capability is False as invalid (force
+ re-pin), even when it would otherwise be reused as the thread's
+ stable model.
+3. Raise ``ValueError`` (mapped to the friendly
+ ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming
+ task) when no vision-capable cfg is available — instead of silently
+ pinning text-only and 404-ing at the provider.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+
+import pytest
+
+from app.services.auto_model_pin_service import (
+ clear_healthy,
+ clear_runtime_cooldown,
+ resolve_or_get_pinned_llm_config_id,
+)
+
+pytestmark = pytest.mark.unit
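+
+
+def _image_aware_filter_sketch(
+    candidates: list[dict], pinned_id: int | None, requires_image_input: bool
+) -> tuple[list[dict], int | None]:
+    """Illustrative sketch only (assumed shape, not the real resolver in
+    ``auto_model_pin_service``): the image-aware rules these tests pin.
+    On an image turn, drop text-only candidates, invalidate a text-only
+    pin, and raise when nothing vision-capable remains."""
+    if not requires_image_input:
+        return candidates, pinned_id
+    # (The real resolver consults ``derive_supports_image_input`` for cfgs
+    # that omit the flag; defaulting to True keeps the sketch short.)
+    vision = [c for c in candidates if c.get("supports_image_input", True)]
+    if not vision:
+        # Mapped to MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT by the streaming task.
+        raise ValueError("no vision-capable model available for this image turn")
+    if pinned_id is not None and all(c["id"] != pinned_id for c in vision):
+        pinned_id = None  # force re-pin onto a vision-capable cfg
+    return vision, pinned_id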
+
+
+@pytest.fixture(autouse=True)
+def _reset_caches():
+ clear_runtime_cooldown()
+ clear_healthy()
+ yield
+ clear_runtime_cooldown()
+ clear_healthy()
+
+
+@dataclass
+class _FakeQuotaResult:
+ allowed: bool
+
+
+class _FakeExecResult:
+ def __init__(self, thread):
+ self._thread = thread
+
+ def unique(self):
+ return self
+
+ def scalar_one_or_none(self):
+ return self._thread
+
+
+class _FakeSession:
+ def __init__(self, thread):
+ self.thread = thread
+ self.commit_count = 0
+
+ async def execute(self, _stmt):
+ return _FakeExecResult(self.thread)
+
+ async def commit(self):
+ self.commit_count += 1
+
+
+def _thread(*, pinned: int | None = None):
+ return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
+
+
+def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
+ return {
+ "id": id_,
+ "provider": "OPENAI",
+ "model_name": f"vision-{id_}",
+ "api_key": "k",
+ "billing_tier": tier,
+ "supports_image_input": True,
+ "auto_pin_tier": "A",
+ "quality_score": quality,
+ }
+
+
+def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
+ return {
+ "id": id_,
+ "provider": "OPENAI",
+ "model_name": f"text-{id_}",
+ "api_key": "k",
+ "billing_tier": tier,
+ # Higher quality than the vision cfgs — so a bug that ignores
+ # the image flag would surface as the resolver picking this one.
+ "supports_image_input": False,
+ "auto_pin_tier": "A",
+ "quality_score": quality,
+ }
+
+
+async def _premium_allowed(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=True)
+
+
+@pytest.mark.asyncio
+async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [_text_only_cfg(-1), _vision_cfg(-2)],
+ )
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ requires_image_input=True,
+ )
+
+ assert result.resolved_llm_config_id == -2
+ # The thread should be pinned to the vision cfg even though the
+ # text-only cfg has a higher quality score.
+ assert session.thread.pinned_llm_config_id == -2
+
+
+@pytest.mark.asyncio
+async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
+ """An existing text-only pin must be invalidated when the next turn
+ requires image input. The non-image path would happily reuse it."""
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned=-1))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [_text_only_cfg(-1), _vision_cfg(-2)],
+ )
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ requires_image_input=True,
+ )
+
+ assert result.resolved_llm_config_id == -2
+ assert result.from_existing_pin is False
+ assert session.thread.pinned_llm_config_id == -2
+
+
+@pytest.mark.asyncio
+async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
+ """If the thread is already pinned to a vision-capable cfg, reuse it
+ — same as the non-image path. Image-aware filtering must not force
+ spurious re-pins."""
+ from app.config import config
+
+ session = _FakeSession(_thread(pinned=-2))
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
+ )
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ requires_image_input=True,
+ )
+
+ assert result.resolved_llm_config_id == -2
+ assert result.from_existing_pin is True
+
+
+@pytest.mark.asyncio
+async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
+ """The friendly-error path: no vision-capable cfg in the pool -> raise
+ ``ValueError`` whose message contains ``vision-capable`` so the
+ streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``."""
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [_text_only_cfg(-1), _text_only_cfg(-2)],
+ )
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ with pytest.raises(ValueError, match="vision-capable"):
+ await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ requires_image_input=True,
+ )
+
+
+@pytest.mark.asyncio
+async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
+ """Regression guard: the image flag must default False and not affect
+ a normal text-only turn — text-only cfgs remain selectable."""
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [_text_only_cfg(-1)],
+ )
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -1
+
+
+@pytest.mark.asyncio
+async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
+ """A YAML cfg that omits ``supports_image_input`` falls through to
+ ``derive_supports_image_input`` (LiteLLM-driven). For ``gpt-4o``
+ that returns True, so the cfg should be a valid candidate."""
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ cfg_unannotated_vision = {
+ "id": -2,
+ "provider": "OPENAI",
+ "model_name": "gpt-4o", # known vision model in LiteLLM map
+ "api_key": "k",
+ "billing_tier": "free",
+ "auto_pin_tier": "A",
+ "quality_score": 80,
+ # NOTE: no supports_image_input key
+ }
+ monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision])
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _premium_allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id=None,
+ selected_llm_config_id=0,
+ requires_image_input=True,
+ )
+ assert result.resolved_llm_config_id == -2
diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py
index 86de5f23d..c820724ed 100644
--- a/surfsense_backend/tests/unit/services/test_billable_call.py
+++ b/surfsense_backend/tests/unit/services/test_billable_call.py
@@ -15,6 +15,7 @@ vision LLM extraction:
from __future__ import annotations
+import asyncio
import contextlib
from typing import Any
from uuid import uuid4
@@ -57,6 +58,9 @@ class _FakeSession:
async def commit(self) -> None:
self.committed = True
+ async def rollback(self) -> None:
+ pass
+
async def close(self) -> None:
pass
@@ -71,7 +75,9 @@ async def _fake_shielded_session():
_SESSIONS_USED: list[_FakeSession] = []
-def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
+def _patch_isolation_layer(
+ monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None
+):
"""Wire fake reserve/finalize/release/session helpers."""
_SESSIONS_USED.clear()
reserve_calls: list[dict[str, Any]] = []
@@ -91,6 +97,8 @@ def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None)
async def _fake_finalize(
*, db_session, user_id, request_id, actual_micros, reserved_micros
):
+ if finalize_exc is not None:
+ raise finalize_exc
finalize_calls.append(
{
"user_id": user_id,
@@ -343,6 +351,125 @@ async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
assert spies["reserve"][0]["reserve_micros"] == 12_345
+@pytest.mark.asyncio
+async def test_premium_finalize_failure_propagates_and_releases(monkeypatch):
+ from app.services.billable_calls import BillingSettlementError, billable_call
+
+ class _FinalizeError(RuntimeError):
+ pass
+
+ spies = _patch_isolation_layer(
+ monkeypatch,
+ reserve_result=_FakeQuotaResult(allowed=True),
+ finalize_exc=_FinalizeError("db finalize failed"),
+ )
+ user_id = uuid4()
+
+ with pytest.raises(BillingSettlementError):
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ) as acc:
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=40_000,
+ call_kind="image_generation",
+ )
+
+ assert len(spies["reserve"]) == 1
+ assert len(spies["release"]) == 1
+ assert spies["record"] == []
+
+
+@pytest.mark.asyncio
+async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ class _HangingCommitSession(_FakeSession):
+ async def commit(self) -> None:
+ await asyncio.sleep(60)
+
+ @contextlib.asynccontextmanager
+ async def _hanging_session_factory():
+ s = _HangingCommitSession()
+ _SESSIONS_USED.append(s)
+ yield s
+
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ billable_session_factory=_hanging_session_factory,
+ audit_timeout_seconds=0.01,
+ ) as acc:
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=40_000,
+ call_kind="image_generation",
+ )
+
+ assert len(spies["reserve"]) == 1
+ assert len(spies["finalize"]) == 1
+ assert len(spies["record"]) == 1
+ assert spies["release"] == []
+
+
+@pytest.mark.asyncio
+async def test_free_audit_failure_is_best_effort(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+
+ async def _failing_record(_session, **_kwargs):
+ raise RuntimeError("audit insert failed")
+
+ monkeypatch.setattr(
+ "app.services.billable_calls.record_token_usage",
+ _failing_record,
+ raising=False,
+ )
+
+ async with billable_call(
+ user_id=uuid4(),
+ search_space_id=42,
+ billing_tier="free",
+ base_model="openai/gpt-image-1",
+ usage_type="image_generation",
+ audit_timeout_seconds=0.01,
+ ) as acc:
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=37_000,
+ call_kind="image_generation",
+ )
+
+ assert spies["reserve"] == []
+ assert spies["finalize"] == []
+
+
# ---------------------------------------------------------------------------
# Podcast / video-presentation usage_type coverage
# ---------------------------------------------------------------------------
@@ -387,7 +514,7 @@ async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
assert len(spies["record"]) == 1
row = spies["record"][0]
assert row["usage_type"] == "podcast_generation"
- assert row["thread_id"] == 99
+ assert row["thread_id"] is None
assert row["search_space_id"] == 42
assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py
new file mode 100644
index 000000000..9d5fdb190
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py
@@ -0,0 +1,177 @@
+"""Defense-in-depth: image-gen call sites must not let an empty
+``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.
+
+The bug repro: an OpenRouter image-gen config ships
+``api_base=""``. The pre-fix call site in
+``image_generation_routes._execute_image_generation`` did
+``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
+silently dropped the empty string. LiteLLM then fell back to
+``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
+and OpenRouter's ``image_generation/transformation`` appended
+``/chat/completions`` to it → 404 ``Resource not found``.
+
+This test pins the post-fix behaviour: with an empty ``api_base`` in
+the config, the call site MUST set ``api_base`` to OpenRouter's public
+URL instead of leaving it unset.
+"""
+
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+pytestmark = pytest.mark.unit
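+
+
+def _api_base_guard_sketch(cfg: dict, kwargs: dict) -> dict:
+    """Illustrative sketch only (assumed call-site shape, not the route
+    code): the post-fix pattern these tests pin. Instead of the falsy check
+    that dropped ``api_base=""``, the call site runs the shared resolver and
+    only then decides whether to forward an ``api_base`` kwarg."""
+    from app.services.provider_api_base import resolve_api_base
+
+    api_base = resolve_api_base(
+        provider=cfg.get("provider"),
+        provider_prefix="openrouter" if cfg.get("provider") == "OPENROUTER" else None,
+        config_api_base=cfg.get("api_base"),
+    )
+    if api_base:
+        kwargs["api_base"] = api_base  # never inherit litellm.api_base
+    return kwargs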
+
+
+@pytest.mark.asyncio
+async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
+ """The global-config branch (``config_id < 0``) of
+ ``_execute_image_generation`` must apply the resolver and pin
+ ``api_base`` to OpenRouter when the config ships an empty string.
+ """
+ from app.routes import image_generation_routes
+
+ cfg = {
+ "id": -20_001,
+ "name": "GPT Image 1 (OpenRouter)",
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-image-1",
+ "api_key": "sk-or-test",
+ "api_base": "", # the original bug shape
+ "api_version": None,
+ "litellm_params": {},
+ }
+
+ captured: dict = {}
+
+ async def fake_aimage_generation(**kwargs):
+ captured.update(kwargs)
+ return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})
+
+ image_gen = MagicMock()
+ image_gen.image_generation_config_id = cfg["id"]
+ image_gen.prompt = "test"
+ image_gen.n = 1
+ image_gen.quality = None
+ image_gen.size = None
+ image_gen.style = None
+ image_gen.response_format = None
+ image_gen.model = None
+
+ search_space = MagicMock()
+ search_space.image_generation_config_id = cfg["id"]
+ session = MagicMock()
+
+ with (
+ patch.object(
+ image_generation_routes,
+ "_get_global_image_gen_config",
+ return_value=cfg,
+ ),
+ patch.object(
+ image_generation_routes,
+ "aimage_generation",
+ side_effect=fake_aimage_generation,
+ ),
+ ):
+ await image_generation_routes._execute_image_generation(
+ session=session, image_gen=image_gen, search_space=search_space
+ )
+
+ # The whole point of the fix: even with empty ``api_base`` in the
+ # config, we forward OpenRouter's public URL so the call doesn't
+ # inherit an Azure endpoint.
+ assert captured.get("api_base") == "https://openrouter.ai/api/v1"
+ assert captured["model"] == "openrouter/openai/gpt-image-1"
+
+
+@pytest.mark.asyncio
+async def test_generate_image_tool_global_sets_api_base_when_config_empty():
+ """Same defense at the agent tool entry point — both surfaces share
+ the same OpenRouter config payloads."""
+ from app.agents.new_chat.tools import generate_image as gi_module
+
+ cfg = {
+ "id": -20_001,
+ "name": "GPT Image 1 (OpenRouter)",
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-image-1",
+ "api_key": "sk-or-test",
+ "api_base": "",
+ "api_version": None,
+ "litellm_params": {},
+ }
+
+ captured: dict = {}
+
+ async def fake_aimage_generation(**kwargs):
+ captured.update(kwargs)
+ response = MagicMock()
+ response.model_dump.return_value = {
+ "data": [{"url": "https://example.com/x.png"}]
+ }
+ response._hidden_params = {"model": "openrouter/openai/gpt-image-1"}
+ return response
+
+ search_space = MagicMock()
+ search_space.id = 1
+ search_space.image_generation_config_id = cfg["id"]
+
+ session_cm = AsyncMock()
+ session = AsyncMock()
+ session_cm.__aenter__.return_value = session
+
+ scalars = MagicMock()
+ scalars.first.return_value = search_space
+ exec_result = MagicMock()
+ exec_result.scalars.return_value = scalars
+ session.execute.return_value = exec_result
+ session.add = MagicMock()
+ session.commit = AsyncMock()
+ session.refresh = AsyncMock()
+
+ # ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback.
+ async def _refresh(obj):
+ obj.id = 1
+
+ session.refresh.side_effect = _refresh
+
+ with (
+ patch.object(gi_module, "shielded_async_session", return_value=session_cm),
+ patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
+ patch.object(
+ gi_module, "aimage_generation", side_effect=fake_aimage_generation
+ ),
+ patch.object(
+ gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0
+ ),
+ ):
+ tool = gi_module.create_generate_image_tool(
+ search_space_id=1, db_session=MagicMock()
+ )
+ await tool.ainvoke({"prompt": "a cat", "n": 1})
+
+ assert captured.get("api_base") == "https://openrouter.ai/api/v1"
+ assert captured["model"] == "openrouter/openai/gpt-image-1"
+
+
+def test_image_gen_router_deployment_sets_api_base_when_config_empty():
+ """The Auto-mode router pool must also resolve ``api_base`` when an
+ OpenRouter config ships an empty string. The deployment dict is fed
+ straight to ``litellm.Router``, so a missing ``api_base`` would
+ leak the same way as the direct call sites.
+ """
+ from app.services.image_gen_router_service import ImageGenRouterService
+
+ deployment = ImageGenRouterService._config_to_deployment(
+ {
+ "model_name": "openai/gpt-image-1",
+ "provider": "OPENROUTER",
+ "api_key": "sk-or-test",
+ "api_base": "",
+ }
+ )
+ assert deployment is not None
+ assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
+ assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-image-1"
diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py
index b635b4fe8..88fcf2db3 100644
--- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py
+++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py
@@ -265,6 +265,10 @@ def test_generate_image_gen_configs_filters_by_image_output():
assert c["billing_tier"] in {"free", "premium"}
assert c["provider"] == "OPENROUTER"
assert c[_OPENROUTER_DYNAMIC_MARKER] is True
+ # Defense-in-depth: emit the OpenRouter base URL at source so a
+ # downstream call site that forgets ``resolve_api_base`` still
+ # doesn't 404 against an inherited Azure endpoint.
+ assert c["api_base"] == "https://openrouter.ai/api/v1"
def test_generate_image_gen_configs_assigns_image_id_offset():
@@ -342,6 +346,10 @@ def test_generate_vision_llm_configs_filters_by_image_input_text_output():
assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
+ # Defense-in-depth: emit the OpenRouter base URL at source so a
+ # downstream call site that forgets ``resolve_api_base`` still
+ # doesn't inherit an Azure endpoint.
+ assert cfg["api_base"] == "https://openrouter.ai/api/v1"
def test_generate_vision_llm_configs_drops_chat_only_filters():
diff --git a/surfsense_backend/tests/unit/services/test_provider_api_base.py b/surfsense_backend/tests/unit/services/test_provider_api_base.py
new file mode 100644
index 000000000..12cd0a3d5
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_provider_api_base.py
@@ -0,0 +1,107 @@
+"""Unit tests for the shared ``api_base`` resolver.
+
+The cascade exists so vision and image-gen call sites can't silently
+inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``)
+when an OpenRouter / Groq / etc. config ships an empty string. See
+``provider_api_base`` module docstring for the original repro
+(OpenRouter image-gen 404-ing against an Azure endpoint).
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from app.services.provider_api_base import (
+ PROVIDER_DEFAULT_API_BASE,
+ PROVIDER_KEY_DEFAULT_API_BASE,
+ resolve_api_base,
+)
+
+pytestmark = pytest.mark.unit
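+
+
+def _cascade_sketch(
+    provider: str | None, provider_prefix: str | None, config_api_base: str | None
+) -> str | None:
+    """Illustrative sketch only (assumed shape, not the real implementation):
+    the cascade these tests pin. A non-blank config value wins, then the
+    provider-key map, then the provider-prefix map, else ``None`` so LiteLLM
+    applies its own provider default."""
+    if config_api_base and config_api_base.strip():
+        return config_api_base
+    if provider and provider.upper() in PROVIDER_KEY_DEFAULT_API_BASE:
+        return PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]
+    if provider_prefix and provider_prefix.lower() in PROVIDER_DEFAULT_API_BASE:
+        return PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]
+    return None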
+
+
+def test_config_value_wins_over_defaults():
+ """A non-empty config value is always returned verbatim, even when the
+ provider has a default — the operator gets the last word."""
+ result = resolve_api_base(
+ provider="OPENROUTER",
+ provider_prefix="openrouter",
+ config_api_base="https://my-openrouter-mirror.example.com/v1",
+ )
+ assert result == "https://my-openrouter-mirror.example.com/v1"
+
+
+def test_provider_key_default_when_config_missing():
+ """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own
+ base URL — the provider-key map must take precedence over the prefix
+ map so DeepSeek requests don't go to OpenAI."""
+ result = resolve_api_base(
+ provider="DEEPSEEK",
+ provider_prefix="openai",
+ config_api_base=None,
+ )
+ assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
+
+
+def test_provider_prefix_default_when_no_key_default():
+ result = resolve_api_base(
+ provider="OPENROUTER",
+ provider_prefix="openrouter",
+ config_api_base=None,
+ )
+ assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
+
+
+def test_unknown_provider_returns_none():
+ """When neither map matches we return ``None`` so the caller can let
+ LiteLLM apply its own provider-integration default (Azure deployment
+ URL, custom-provider URL, etc.)."""
+ result = resolve_api_base(
+ provider="SOMETHING_NEW",
+ provider_prefix="something_new",
+ config_api_base=None,
+ )
+ assert result is None
+
+
+def test_empty_string_config_treated_as_missing():
+ """The original bug: OpenRouter dynamic configs ship ``api_base=""``
+ and downstream call sites use ``if cfg.get("api_base"):`` — empty
+ strings are falsy in Python but the cascade has to step in anyway."""
+ result = resolve_api_base(
+ provider="OPENROUTER",
+ provider_prefix="openrouter",
+ config_api_base="",
+ )
+ assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
+
+
+def test_whitespace_only_config_treated_as_missing():
+ """A config value of ``" "`` is a configuration mistake — treat it
+ as missing instead of forwarding whitespace to LiteLLM (which would
+ almost certainly 404)."""
+ result = resolve_api_base(
+ provider="OPENROUTER",
+ provider_prefix="openrouter",
+ config_api_base=" ",
+ )
+ assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
+
+
+def test_provider_case_insensitive():
+ """Some call sites pass the provider lowercase (DB enum value), others
+ uppercase (YAML key). Both must resolve."""
+ upper = resolve_api_base(
+ provider="DEEPSEEK", provider_prefix="openai", config_api_base=None
+ )
+ lower = resolve_api_base(
+ provider="deepseek", provider_prefix="openai", config_api_base=None
+ )
+ assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
+
+
+def test_all_inputs_none_returns_none():
+ assert (
+ resolve_api_base(provider=None, provider_prefix=None, config_api_base=None)
+ is None
+ )
diff --git a/surfsense_backend/tests/unit/services/test_provider_capabilities.py b/surfsense_backend/tests/unit/services/test_provider_capabilities.py
new file mode 100644
index 000000000..aac88977f
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_provider_capabilities.py
@@ -0,0 +1,244 @@
+"""Unit tests for the shared chat-image capability resolver.
+
+Two resolvers, two intents:
+
+- ``derive_supports_image_input`` — best-effort True for the catalog and
+ selector. Default-allow on unknown / unmapped models. The streaming
+ task safety net never sees this value directly.
+
+- ``is_known_text_only_chat_model`` — strict opt-out for the safety net.
+ Returns True only when LiteLLM's model map *explicitly* sets
+ ``supports_vision=False``. Anything else (missing key, exception,
+ True) returns False so the request flows through to the provider.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from app.services.provider_capabilities import (
+ derive_supports_image_input,
+ is_known_text_only_chat_model,
+)
+
+pytestmark = pytest.mark.unit
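+
+
+def _strict_opt_out_sketch(model_name: str) -> bool:
+    """Illustrative sketch only (assumed semantics, not the real helper):
+    the safety net counts a model as text-only *only* when LiteLLM's map
+    explicitly says ``supports_vision: False``. Missing entries, missing
+    keys, and lookup errors all resolve to False so the request flows
+    through to the provider."""
+    import litellm
+
+    try:
+        info = litellm.get_model_info(model=model_name)
+    except Exception:
+        return False  # transient lookup failure: never block
+    return info.get("supports_vision") is False  # missing key: never block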
+
+
+# ---------------------------------------------------------------------------
+# derive_supports_image_input — OpenRouter modalities path (authoritative)
+# ---------------------------------------------------------------------------
+
+
+def test_or_modalities_with_image_returns_true():
+ assert (
+ derive_supports_image_input(
+ provider="OPENROUTER",
+ model_name="openai/gpt-4o",
+ openrouter_input_modalities=["text", "image"],
+ )
+ is True
+ )
+
+
+def test_or_modalities_text_only_returns_false():
+ assert (
+ derive_supports_image_input(
+ provider="OPENROUTER",
+ model_name="deepseek/deepseek-v3.2-exp",
+ openrouter_input_modalities=["text"],
+ )
+ is False
+ )
+
+
+def test_or_modalities_empty_list_returns_false():
+ """OR explicitly publishing an empty modality list is a definitive
+ 'no inputs at all' signal — treat as False rather than falling back
+ to LiteLLM."""
+ assert (
+ derive_supports_image_input(
+ provider="OPENROUTER",
+ model_name="weird/empty-modalities",
+ openrouter_input_modalities=[],
+ )
+ is False
+ )
+
+
+def test_or_modalities_none_falls_through_to_litellm():
+ """``None`` (missing key) is *not* a definitive signal — fall through
+    to LiteLLM. Uses ``gpt-4o`` (provider ``OPENAI``), which is in LiteLLM's map."""
+ assert (
+ derive_supports_image_input(
+ provider="OPENAI",
+ model_name="gpt-4o",
+ openrouter_input_modalities=None,
+ )
+ is True
+ )
+
+
+# ---------------------------------------------------------------------------
+# derive_supports_image_input — LiteLLM model-map path
+# ---------------------------------------------------------------------------
+
+
+def test_litellm_known_vision_model_returns_true():
+ assert (
+ derive_supports_image_input(
+ provider="OPENAI",
+ model_name="gpt-4o",
+ )
+ is True
+ )
+
+
+def test_litellm_base_model_wins_over_model_name():
+ """Azure-style entries pass model_name=deployment_id and put the
+ canonical sku in litellm_params.base_model. The resolver must
+ consult base_model first or the deployment id (which LiteLLM
+ doesn't know) would shadow the real capability."""
+ assert (
+ derive_supports_image_input(
+ provider="AZURE_OPENAI",
+ model_name="my-azure-deployment-id",
+ base_model="gpt-4o",
+ )
+ is True
+ )
+
+
+def test_litellm_unknown_model_default_allows():
+ """Default-allow on unknown — the safety net is the actual block."""
+ assert (
+ derive_supports_image_input(
+ provider="CUSTOM",
+ model_name="brand-new-model-x9-unmapped",
+ custom_provider="brand_new_proxy",
+ )
+ is True
+ )
+
+
+def test_litellm_known_text_only_model_resolves_without_raising():
+    """``deepseek-chat`` (the DeepSeek-V3 chat sku) is a text-only model.
+    The catalog resolver may return either False (LiteLLM's map says an
+    explicit no) or True (default-allow if the entry isn't mapped on the
+    installed LiteLLM version), so this test only pins that the resolver
+    returns a bool and never raises for a known-text-only provider/model.
+    The behaviour-binding assertion lives in
+    ``test_is_known_text_only_returns_true_on_explicit_false`` below."""
+    result = derive_supports_image_input(
+        provider="DEEPSEEK",
+        model_name="deepseek-chat",
+    )
+    assert isinstance(result, bool)
+
+
+# ---------------------------------------------------------------------------
+# is_known_text_only_chat_model — strict opt-out semantics
+# ---------------------------------------------------------------------------
+
+
+def test_is_known_text_only_returns_false_for_vision_model():
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="gpt-4o",
+ )
+ is False
+ )
+
+
+def test_is_known_text_only_returns_false_for_unknown_model():
+ """Strict opt-out: missing from the map ≠ text-only. The safety net
+ must NOT fire for an unmapped model — that's the regression we're
+ fixing."""
+ assert (
+ is_known_text_only_chat_model(
+ provider="CUSTOM",
+ model_name="brand-new-model-x9-unmapped",
+ custom_provider="brand_new_proxy",
+ )
+ is False
+ )
+
+
+def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
+ """LiteLLM's ``get_model_info`` raises freely on parse errors. The
+ helper swallows the exception and returns False so the safety net
+ doesn't fire on a transient lookup failure."""
+ import app.services.provider_capabilities as pc
+
+ def _raise(**_kwargs):
+ raise ValueError("intentional test failure")
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _raise)
+
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="gpt-4o",
+ )
+ is False
+ )
+
+
+def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
+ """Stub LiteLLM's ``get_model_info`` to return an explicit False so
+ we exercise the opt-out path deterministically. Using a stub keeps
+ the test stable across LiteLLM map updates."""
+ import app.services.provider_capabilities as pc
+
+ def _info(**_kwargs):
+ return {"supports_vision": False, "max_input_tokens": 8192}
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info)
+
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="any-model",
+ )
+ is True
+ )
+
+
+def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
+ import app.services.provider_capabilities as pc
+
+ def _info(**_kwargs):
+ return {"supports_vision": True}
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info)
+
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="any-model",
+ )
+ is False
+ )
+
+
+def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
+ """A model entry without ``supports_vision`` at all is treated as
+ 'unknown' — strict opt-out means False."""
+ import app.services.provider_capabilities as pc
+
+ def _info(**_kwargs):
+ return {"max_input_tokens": 8192} # no supports_vision
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info)
+
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="any-model",
+ )
+ is False
+ )
diff --git a/surfsense_backend/tests/unit/services/test_supports_image_input.py b/surfsense_backend/tests/unit/services/test_supports_image_input.py
new file mode 100644
index 000000000..71fdee1c7
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_supports_image_input.py
@@ -0,0 +1,281 @@
+"""Unit tests for the chat-catalog ``supports_image_input`` capability flag.
+
+Capability is sourced from two places, in order of preference:
+
+1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs
+ (authoritative — OpenRouter publishes per-model modalities directly).
+2. LiteLLM's authoritative model map (``litellm.supports_vision``) for
+ YAML / BYOK configs that don't carry an explicit operator override.
+
+The catalog default is *True* (default-allow): an unknown / unmapped
+model is not pre-judged. The streaming-task safety net
+(``is_known_text_only_chat_model``) is the only place a False actually
+blocks a request — and it requires LiteLLM to *explicitly* mark the model
+as text-only.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from app.services.openrouter_integration_service import (
+ _OPENROUTER_DYNAMIC_MARKER,
+ _generate_configs,
+ _supports_image_input,
+)
+
+pytestmark = pytest.mark.unit
+
+
+_SETTINGS_BASE: dict = {
+ "api_key": "sk-or-test",
+ "id_offset": -10_000,
+ "rpm": 200,
+ "tpm": 1_000_000,
+ "free_rpm": 20,
+ "free_tpm": 100_000,
+ "anonymous_enabled_paid": False,
+ "anonymous_enabled_free": True,
+ "quota_reserve_tokens": 4000,
+}
+
+
+# ---------------------------------------------------------------------------
+# _supports_image_input helper (OpenRouter modalities)
+# ---------------------------------------------------------------------------
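+
+
+def _or_modalities_sketch(model_entry: dict) -> bool:
+    """Illustrative sketch only (assumed semantics of the helper exercised
+    below): a dynamic OpenRouter entry counts as image-capable when
+    ``"image"`` appears in ``architecture.input_modalities``; a missing or
+    empty modality list is treated as text-only at this level."""
+    modalities = (model_entry.get("architecture") or {}).get("input_modalities") or []
+    return "image" in modalities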
+
+
+def test_supports_image_input_true_for_multimodal():
+ assert (
+ _supports_image_input(
+ {
+ "id": "openai/gpt-4o",
+ "architecture": {
+ "input_modalities": ["text", "image"],
+ "output_modalities": ["text"],
+ },
+ }
+ )
+ is True
+ )
+
+
+def test_supports_image_input_false_for_text_only():
+ """The exact failure mode the safety net guards against — DeepSeek V3
+ is a text-in/text-out model and would 404 if forwarded image_url."""
+ assert (
+ _supports_image_input(
+ {
+ "id": "deepseek/deepseek-v3.2-exp",
+ "architecture": {
+ "input_modalities": ["text"],
+ "output_modalities": ["text"],
+ },
+ }
+ )
+ is False
+ )
+
+
+def test_supports_image_input_false_when_modalities_missing():
+ """Defensive: missing architecture is treated as text-only at the
+ OpenRouter helper level. The wider catalog resolver
+ (`derive_supports_image_input`) only consults modalities when they
+ are non-empty, otherwise it falls back to LiteLLM."""
+ assert _supports_image_input({"id": "weird/model"}) is False
+ assert _supports_image_input({"id": "weird/model", "architecture": {}}) is False
+ assert (
+ _supports_image_input(
+ {"id": "weird/model", "architecture": {"input_modalities": None}}
+ )
+ is False
+ )
+
+
+# ---------------------------------------------------------------------------
+# _generate_configs threads the flag onto every emitted chat config
+# ---------------------------------------------------------------------------
+
+
+def test_generate_configs_emits_supports_image_input():
+ raw = [
+ {
+ "id": "openai/gpt-4o",
+ "architecture": {
+ "input_modalities": ["text", "image"],
+ "output_modalities": ["text"],
+ },
+ "supported_parameters": ["tools"],
+ "context_length": 200_000,
+ "pricing": {"prompt": "0.000005", "completion": "0.000015"},
+ },
+ {
+ "id": "deepseek/deepseek-v3.2-exp",
+ "architecture": {
+ "input_modalities": ["text"],
+ "output_modalities": ["text"],
+ },
+ "supported_parameters": ["tools"],
+ "context_length": 200_000,
+ "pricing": {"prompt": "0.000003", "completion": "0.000015"},
+ },
+ ]
+ cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
+ by_model = {c["model_name"]: c for c in cfgs}
+
+ gpt = by_model["openai/gpt-4o"]
+ assert gpt["supports_image_input"] is True
+ assert gpt[_OPENROUTER_DYNAMIC_MARKER] is True
+
+ deepseek = by_model["deepseek/deepseek-v3.2-exp"]
+ assert deepseek["supports_image_input"] is False
+ assert deepseek[_OPENROUTER_DYNAMIC_MARKER] is True
+
+
+# ---------------------------------------------------------------------------
+# YAML loader: defer to derive_supports_image_input on unannotated entries
+# ---------------------------------------------------------------------------
+
+
+def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch):
+ """The regression case: an Azure GPT-5.x YAML entry without a
+ ``supports_image_input`` override should resolve to True via LiteLLM's
+ model map (which says ``supports_vision: true``). Previously this
+ defaulted to False, blocking every image turn for vision-capable
+ YAML configs."""
+ yaml_dir = tmp_path / "app" / "config"
+ yaml_dir.mkdir(parents=True)
+ (yaml_dir / "global_llm_config.yaml").write_text(
+ """
+global_llm_configs:
+ - id: -2
+ name: Azure GPT-4o
+ provider: AZURE_OPENAI
+ model_name: gpt-4o
+ api_key: sk-test
+""",
+ encoding="utf-8",
+ )
+
+ from app import config as config_module
+
+ monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
+
+ configs = config_module.load_global_llm_configs()
+ assert len(configs) == 1
+ assert configs[0]["supports_image_input"] is True
+
+
+def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch):
+ yaml_dir = tmp_path / "app" / "config"
+ yaml_dir.mkdir(parents=True)
+ (yaml_dir / "global_llm_config.yaml").write_text(
+ """
+global_llm_configs:
+ - id: -1
+ name: GPT-4o
+ provider: OPENAI
+ model_name: gpt-4o
+ api_key: sk-test
+ supports_image_input: false
+""",
+ encoding="utf-8",
+ )
+
+ from app import config as config_module
+
+ monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
+
+ configs = config_module.load_global_llm_configs()
+ assert len(configs) == 1
+ # Operator override always wins, even against LiteLLM's True.
+ assert configs[0]["supports_image_input"] is False
+
+
+def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch):
+ """Unknown / unmapped model in YAML: default-allow. The streaming
+ safety net (which requires an explicit-False from LiteLLM) is the
+ only place a real block happens, so we don't lock the user out of
+ a freshly added third-party entry the catalog can't introspect."""
+ yaml_dir = tmp_path / "app" / "config"
+ yaml_dir.mkdir(parents=True)
+ (yaml_dir / "global_llm_config.yaml").write_text(
+ """
+global_llm_configs:
+ - id: -1
+ name: Some Brand New Model
+ provider: CUSTOM
+ custom_provider: brand_new_proxy
+ model_name: brand-new-model-x9
+ api_key: sk-test
+""",
+ encoding="utf-8",
+ )
+
+ from app import config as config_module
+
+ monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
+
+ configs = config_module.load_global_llm_configs()
+ assert len(configs) == 1
+ assert configs[0]["supports_image_input"] is True
+
+
+# ---------------------------------------------------------------------------
+# AgentConfig threads the flag through both YAML and Auto / BYOK
+# ---------------------------------------------------------------------------
+
+
+def test_agent_config_from_yaml_explicit_overrides_resolver():
+ from app.agents.new_chat.llm_config import AgentConfig
+
+ cfg_text_only = AgentConfig.from_yaml_config(
+ {
+ "id": -1,
+ "name": "Text Only Override",
+ "provider": "openai",
+ "model_name": "gpt-4o", # Capable per LiteLLM, but operator says no.
+ "api_key": "sk-test",
+ "supports_image_input": False,
+ }
+ )
+ cfg_explicit_vision = AgentConfig.from_yaml_config(
+ {
+ "id": -2,
+ "name": "GPT-4o",
+ "provider": "openai",
+ "model_name": "gpt-4o",
+ "api_key": "sk-test",
+ "supports_image_input": True,
+ }
+ )
+ assert cfg_text_only.supports_image_input is False
+ assert cfg_explicit_vision.supports_image_input is True
+
+
+def test_agent_config_from_yaml_unannotated_uses_resolver():
+ """Without an explicit YAML key, AgentConfig defers to the catalog
+ resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True."""
+ from app.agents.new_chat.llm_config import AgentConfig
+
+ cfg = AgentConfig.from_yaml_config(
+ {
+ "id": -1,
+ "name": "GPT-4o (no override)",
+ "provider": "openai",
+ "model_name": "gpt-4o",
+ "api_key": "sk-test",
+ }
+ )
+ assert cfg.supports_image_input is True
+
+
+def test_agent_config_auto_mode_supports_image_input():
+ """Auto routes across the pool. We optimistically allow image input
+ so users can keep their selection on Auto with a vision-capable
+ deployment somewhere in the pool. The router's own `allowed_fails`
+ handles non-vision deployments via fallback."""
+ from app.agents.new_chat.llm_config import AgentConfig
+
+ auto = AgentConfig.from_auto_mode()
+ assert auto.supports_image_input is True
diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py
new file mode 100644
index 000000000..b8ba9d80c
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py
@@ -0,0 +1,89 @@
+"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
+defaults from ``litellm.api_base`` either.
+
+Vision shares the same shape as image-gen — global YAML / OpenRouter
+dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm``
+call sites would silently drop the empty string and inherit
+``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
+construction so we test the kwargs we hand to it instead.
+"""
+
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.asyncio
+async def test_get_vision_llm_global_openrouter_sets_api_base():
+ """Global negative-ID branch: an OpenRouter vision config with
+ ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with
+ ``api_base="https://openrouter.ai/api/v1"`` — never an empty string,
+ never silently absent."""
+ from app.services import llm_service
+
+ cfg = {
+ "id": -30_001,
+ "name": "GPT-4o Vision (OpenRouter)",
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-4o",
+ "api_key": "sk-or-test",
+ "api_base": "",
+ "api_version": None,
+ "litellm_params": {},
+ "billing_tier": "free",
+ }
+
+ search_space = MagicMock()
+ search_space.id = 1
+ search_space.user_id = "user-x"
+ search_space.vision_llm_config_id = cfg["id"]
+
+ session = AsyncMock()
+ scalars = MagicMock()
+ scalars.first.return_value = search_space
+ result = MagicMock()
+ result.scalars.return_value = scalars
+ session.execute.return_value = result
+
+ captured: dict = {}
+
+ class FakeSanitized:
+ def __init__(self, **kwargs):
+ captured.update(kwargs)
+
+ with (
+ patch(
+ "app.services.vision_llm_router_service.get_global_vision_llm_config",
+ return_value=cfg,
+ ),
+ patch(
+ "app.agents.new_chat.llm_config.SanitizedChatLiteLLM",
+ new=FakeSanitized,
+ ),
+ ):
+ await llm_service.get_vision_llm(session=session, search_space_id=1)
+
+ assert captured.get("api_base") == "https://openrouter.ai/api/v1"
+ assert captured["model"] == "openrouter/openai/gpt-4o"
+
+
+def test_vision_router_deployment_sets_api_base_when_config_empty():
+ """Auto-mode vision router: deployments are fed to ``litellm.Router``,
+ so the resolver has to apply at deployment construction time too."""
+ from app.services.vision_llm_router_service import VisionLLMRouterService
+
+ deployment = VisionLLMRouterService._config_to_deployment(
+ {
+ "model_name": "openai/gpt-4o",
+ "provider": "OPENROUTER",
+ "api_key": "sk-or-test",
+ "api_base": "",
+ }
+ )
+ assert deployment is not None
+ assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
+ assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o"
diff --git a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py
new file mode 100644
index 000000000..a5bb3f58a
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py
@@ -0,0 +1,318 @@
+"""Regression tests for ``run_async_celery_task``.
+
+These tests pin down the production bug observed on 2026-05-02 where
+the video-presentation Celery task hung at ``[billable_call] finalize``
+because the shared ``app.db.engine`` had pooled asyncpg connections
+bound to a *previous* task's now-closed event loop. Reusing such a
+connection on a fresh loop crashes inside ``pool_pre_ping`` with::
+
+ AttributeError: 'NoneType' object has no attribute 'send'
+
+(the proactor is None because the loop is gone) and can hang forever
+inside the asyncpg ``Connection._cancel`` cleanup coroutine.
+
+The fix is ``run_async_celery_task``: a small helper that runs every
+async celery task body inside a fresh event loop and disposes the
+shared engine pool both before (defends against a previous task that
+crashed) and after (releases connections we opened on this loop).
+
+Tests here exercise the helper with a stub engine that records
+``dispose()`` calls and panics if a coroutine produced by one loop is
+awaited on another — mirroring the real asyncpg behaviour.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import gc
+import sys
+from collections.abc import Iterator
+from contextlib import contextmanager
+from unittest.mock import patch
+
+import pytest
+
+pytestmark = pytest.mark.unit
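+
+
+def _runner_contract_sketch(body):
+    """Illustrative sketch only (assumed shape, not the real helper): the
+    contract these tests pin for ``run_async_celery_task``. A fresh event
+    loop per invocation, the shared engine pool disposed both before the
+    body (defends against a crashed predecessor) and after it (releases
+    this loop's connections), async generators shut down, loop closed.
+    The real helper additionally pins the Windows Proactor loop policy."""
+    from contextlib import suppress
+
+    from app.db import engine as shared_engine
+
+    loop = asyncio.new_event_loop()
+    try:
+        with suppress(Exception):  # a flaky dispose() must not kill the task
+            loop.run_until_complete(shared_engine.dispose())
+        try:
+            return loop.run_until_complete(body())
+        finally:
+            with suppress(Exception):
+                loop.run_until_complete(shared_engine.dispose())
+    finally:
+        loop.run_until_complete(loop.shutdown_asyncgens())
+        loop.close()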
+
+
+# ---------------------------------------------------------------------------
+# Stub engine that emulates the asyncpg-on-stale-loop crash
+# ---------------------------------------------------------------------------
+
+
+class _StaleLoopEngine:
+ """Tiny stand-in for ``app.db.engine`` that tracks dispose() calls.
+
+ ``dispose()`` is async (matches ``AsyncEngine.dispose``) and records
+ the running event loop id so tests can assert it ran on *each*
+ fresh loop.
+ """
+
+ def __init__(self) -> None:
+ self.dispose_loop_ids: list[int] = []
+
+ async def dispose(self) -> None:
+ loop = asyncio.get_running_loop()
+ self.dispose_loop_ids.append(id(loop))
+
+
+@contextmanager
+def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]:
+ """Patch ``from app.db import engine as shared_engine`` lookup.
+
+ The helper imports lazily inside the function body, so we have to
+ patch the attribute on the already-loaded ``app.db`` module.
+ """
+ import app.db as app_db
+
+ original = getattr(app_db, "engine", None)
+ app_db.engine = stub # type: ignore[attr-defined]
+ try:
+ yield
+ finally:
+        if original is None:
+            # ``app.db`` exposed no real engine before the test: remove the
+            # stub rather than leaving it bound to the module.
+            if getattr(app_db, "engine", None) is stub:
+                delattr(app_db, "engine")
+        else:
+            app_db.engine = original  # type: ignore[attr-defined]
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+def test_runner_returns_value_and_disposes_engine_around_call() -> None:
+ """Happy path: the coroutine result is returned, and the shared
+ engine is disposed both before and after the task body runs.
+ """
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ stub = _StaleLoopEngine()
+
+ async def _body() -> str:
+ # Engine should already have been disposed once before we run.
+ assert len(stub.dispose_loop_ids) == 1
+ return "ok"
+
+ with _patch_shared_engine(stub):
+ result = run_async_celery_task(_body)
+
+ assert result == "ok"
+ # Once before the body, once after (in finally).
+ assert len(stub.dispose_loop_ids) == 2
+ # Both disposes ran on the SAME (fresh) loop the task body used.
+ assert stub.dispose_loop_ids[0] == stub.dispose_loop_ids[1]
+
+
+def test_runner_creates_fresh_loop_per_invocation() -> None:
+ """Each call must spin its own loop. Without this guarantee a
+ previous task's loop would be reused and the asyncpg-stale-loop
+ crash would never be avoided.
+ """
+ import app.tasks.celery_tasks as celery_tasks_pkg
+
+ stub = _StaleLoopEngine()
+ new_loop_calls = 0
+ closed_loops: list[bool] = []
+
+ real_new_event_loop = asyncio.new_event_loop
+
+ def _counting_new_loop() -> asyncio.AbstractEventLoop:
+ nonlocal new_loop_calls
+ new_loop_calls += 1
+ loop = real_new_event_loop()
+ # Hook close() so we can verify each loop was closed properly
+ # before the next one was created.
+ original_close = loop.close
+
+ def _tracked_close() -> None:
+ closed_loops.append(True)
+ original_close()
+
+ loop.close = _tracked_close # type: ignore[method-assign]
+ return loop
+
+ async def _body() -> None:
+ # Loop is alive and current at body execution time.
+ running = asyncio.get_running_loop()
+ assert not running.is_closed()
+
+ with (
+ _patch_shared_engine(stub),
+ patch.object(asyncio, "new_event_loop", _counting_new_loop),
+ ):
+ for _ in range(3):
+ celery_tasks_pkg.run_async_celery_task(_body)
+
+ assert new_loop_calls == 3
+ assert closed_loops == [True, True, True]
+ # Each invocation disposed twice (before + after).
+ assert len(stub.dispose_loop_ids) == 6
+
+
+def test_runner_disposes_engine_even_when_body_raises() -> None:
+ """Cleanup MUST run on the failure path too — otherwise stale
+ connections leak into the next task and cause the original hang.
+ """
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ stub = _StaleLoopEngine()
+
+ class _BoomError(RuntimeError):
+ pass
+
+ async def _body() -> None:
+ raise _BoomError("kaboom")
+
+ with _patch_shared_engine(stub), pytest.raises(_BoomError):
+ run_async_celery_task(_body)
+
+ assert len(stub.dispose_loop_ids) == 2 # before + after still ran
+
+
+def test_runner_swallows_dispose_errors() -> None:
+ """A flaky engine.dispose() must NEVER take down a celery task.
+
+ Production scenario: the very first dispose (before the body runs)
+ might hit a partially-initialised engine; the helper logs and
+ moves on. The task body still runs; the result is still returned.
+ """
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ class _AngryEngine:
+ def __init__(self) -> None:
+ self.calls = 0
+
+ async def dispose(self) -> None:
+ self.calls += 1
+ raise RuntimeError("dispose() blew up")
+
+ stub = _AngryEngine()
+
+ async def _body() -> int:
+ return 42
+
+ with _patch_shared_engine(stub):
+ assert run_async_celery_task(_body) == 42
+
+ assert stub.calls == 2 # before + after both attempted
+
+
+def test_runner_propagates_value_from_async_body() -> None:
+ """Sanity: pass-through of any pickleable celery return value."""
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ stub = _StaleLoopEngine()
+
+ async def _body() -> dict[str, object]:
+ return {"status": "ready", "video_presentation_id": 19}
+
+ with _patch_shared_engine(stub):
+ out = run_async_celery_task(_body)
+
+ assert out == {"status": "ready", "video_presentation_id": 19}
+
+
+def test_video_presentation_task_uses_runner_helper() -> None:
+ """Defence-in-depth: confirm the celery task module imports
+ ``run_async_celery_task``. If a future refactor inlines a
+ ``loop = asyncio.new_event_loop(); ... loop.close()`` block again,
+ the original hang will return.
+ """
+ # The module's task body should not contain a manual new_event_loop
+ # call — that's exactly what the helper exists to centralise.
+ import inspect
+
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ src = inspect.getsource(video_presentation_tasks)
+ assert "run_async_celery_task" in src, (
+ "video_presentation_tasks.py must use run_async_celery_task; "
+ "manual asyncio.new_event_loop() in a celery task hangs on the "
+ "shared SQLAlchemy pool when reused across tasks."
+ )
+ assert "asyncio.new_event_loop" not in src, (
+ "video_presentation_tasks.py contains a raw asyncio.new_event_loop "
+ "call — route every async task through run_async_celery_task to "
+ "avoid the stale-pool hang."
+ )
+
+
+def test_podcast_task_uses_runner_helper() -> None:
+ """Symmetric assertion for the podcast task — same root cause, same
+ fix, same regression risk.
+ """
+ import inspect
+
+ from app.tasks.celery_tasks import podcast_tasks
+
+ src = inspect.getsource(podcast_tasks)
+ assert "run_async_celery_task" in src
+ assert "asyncio.new_event_loop" not in src
+
+
+def test_runner_runs_shutdown_asyncgens_before_close() -> None:
+ """If the task body created any async generators that didn't get
+ fully iterated, we must still call ``loop.shutdown_asyncgens()``
+ before closing — otherwise we leak event-loop bound resources
+ that re-emerge as ``RuntimeError: Event loop is closed`` later.
+ """
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ stub = _StaleLoopEngine()
+
+ async def _agen():
+ try:
+ yield 1
+ yield 2
+ finally:
+ pass
+
+ async def _body() -> None:
+ # Iterate the agen partially, then leave it dangling — exactly
+ # the situation shutdown_asyncgens() is designed to clean up.
+ async for v in _agen():
+ if v == 1:
+ break
+
+ with _patch_shared_engine(stub):
+ run_async_celery_task(_body)
+
+ # By the time the helper returns, garbage collection + shutdown_asyncgens
+ # should have ensured no live async-gen references remain. We don't
+ # assert agen.closed directly (it depends on GC ordering); the real
+ # contract is "no warnings, no event-loop-closed errors". A successful
+ # second invocation proves the loop was cleaned up properly.
+ with _patch_shared_engine(stub):
+ run_async_celery_task(_body)
+
+ # Force a GC pass to surface any 'coroutine was never awaited'
+ # warnings that would indicate the cleanup is broken.
+ gc.collect()
+
+
+def test_runner_uses_proactor_loop_on_windows() -> None:
+ """On Windows the celery worker preselects a Proactor policy so
+ subprocess (ffmpeg) calls work. The helper must not silently fall
+ back to a Selector loop and re-break video/podcast generation.
+ """
+ if not sys.platform.startswith("win"):
+ pytest.skip("Windows-specific event-loop policy assertion")
+
+ from app.tasks.celery_tasks import run_async_celery_task
+
+ stub = _StaleLoopEngine()
+
+ # Mirror the policy set at the top of every Windows celery task.
+ asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
+
+ observed: list[str] = []
+
+ async def _body() -> None:
+ observed.append(type(asyncio.get_running_loop()).__name__)
+
+ with _patch_shared_engine(stub):
+ run_async_celery_task(_body)
+
+ assert observed == ["ProactorEventLoop"]
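
Taken together, these tests pin down the runner's contract: one fresh event loop per invocation, a shared-engine dispose before and after the task body on that same loop, dispose failures swallowed, async generators shut down, and the loop closed before the next task starts. A minimal sketch of a helper that would satisfy them follows; it is illustrative only, and `async_engine` / `_dispose_shared_engine` are stand-in names rather than the repository's actual symbols.

import asyncio
import logging
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar

logger = logging.getLogger(__name__)
T = TypeVar("T")

# Stand-in for the shared SQLAlchemy AsyncEngine; the real module exposes its own.
async_engine: Any = None


async def _dispose_shared_engine() -> None:
    # A failing dispose() is logged and swallowed so it can never fail the task.
    if async_engine is None:
        return
    try:
        await async_engine.dispose()
    except Exception:
        logger.exception("Shared engine dispose failed; continuing")


def run_async_celery_task(body: Callable[[], Coroutine[Any, Any, T]]) -> T:
    loop = asyncio.new_event_loop()  # fresh loop for every invocation
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(_dispose_shared_engine())  # drop any stale pool first
        try:
            return loop.run_until_complete(body())
        finally:
            loop.run_until_complete(_dispose_shared_engine())  # cleanup path too
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        asyncio.set_event_loop(None)
        loop.close()  # closed before the next task creates its own loop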
diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py
index 38d6ba2ca..699297df1 100644
--- a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py
+++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py
@@ -113,6 +113,19 @@ async def _denying_billable_call(**kwargs):
yield SimpleNamespace() # pragma: no cover — for grammar only
+@contextlib.asynccontextmanager
+async def _settlement_failing_billable_call(**kwargs):
+ from app.services.billable_calls import BillingSettlementError
+
+ _CALL_LOG.append(kwargs)
+ yield SimpleNamespace()
+ raise BillingSettlementError(
+ usage_type=kwargs.get("usage_type", "?"),
+ user_id=kwargs["user_id"],
+ cause=RuntimeError("finalize failed"),
+ )
+
+
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@@ -187,8 +200,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp
call["quota_reserve_micros_override"]
== app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
)
- assert call["thread_id"] == 99
- assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
+ # Background artifact audit rows intentionally omit the TokenUsage.thread_id
+ # FK to avoid coupling Celery audit commits to an active chat transaction.
+ assert "thread_id" not in call
+ assert call["call_details"] == {
+ "podcast_id": 7,
+ "title": "Test Podcast",
+ "thread_id": 99,
+ }
+ assert callable(call["billable_session_factory"])
@pytest.mark.asyncio
@@ -279,6 +299,49 @@ async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypat
assert graph_invoked == [] # Graph never ran on denied reservation.
+@pytest.mark.asyncio
+async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch):
+ from app.db import PodcastStatus
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast(podcast_id=10)
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return uuid4(), "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
+ )
+ monkeypatch.setattr(
+ podcast_tasks, "billable_call", _settlement_failing_billable_call
+ )
+
+ async def _fake_graph_invoke(state, config):
+ return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ result = await podcast_tasks._generate_content_podcast(
+ podcast_id=10,
+ source_content="hi",
+ search_space_id=555,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "podcast_id": 10,
+ "reason": "billing_settlement_failed",
+ }
+ assert podcast.status == PodcastStatus.FAILED
+
+
@pytest.mark.asyncio
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
"""If the resolver raises (e.g. search-space deleted), the task fails
diff --git a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py
new file mode 100644
index 000000000..792d059b0
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py
@@ -0,0 +1,119 @@
+"""Predicate-level test for the chat streaming safety net.
+
+The safety net in ``stream_new_chat`` rejects an image turn early with
+a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the
+selected model is *known* to be text-only. The earlier round of this
+work used a strict opt-in flag (``supports_image_input`` defaulting to
+False on every YAML entry) which blocked vision-capable Azure GPT-5.x
+deployments — this is the regression we're fixing.
+
+The new predicate is :func:`is_known_text_only_chat_model`, which
+returns True only when LiteLLM's authoritative model map *explicitly*
+sets ``supports_vision=False``. Anything else (vision True, missing
+key, exception) returns False so the request flows through to the
+provider.
+
+We exercise the predicate directly here rather than driving the full
+``stream_new_chat`` generator — covering the gate in isolation keeps
+the test focused on the regression while the generator's wider behavior
+is exercised by the integration suite.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from app.services.provider_capabilities import is_known_text_only_chat_model
+
+pytestmark = pytest.mark.unit
+
+
+def test_safety_net_does_not_fire_for_azure_gpt_4o():
+ """Regression: ``azure/gpt-4o`` (and the GPT-5.x variants) is
+ vision-capable per LiteLLM's model map. The previous round's
+ blanket-False default blocked it; the new predicate must NOT mark
+ it text-only."""
+ assert (
+ is_known_text_only_chat_model(
+ provider="AZURE_OPENAI",
+ model_name="my-azure-deployment",
+ base_model="gpt-4o",
+ )
+ is False
+ )
+
+
+def test_safety_net_does_not_fire_for_unknown_model():
+ """Default-pass on unknown — the safety net only blocks definitive
+ text-only confirmations. A freshly added third-party model that
+ LiteLLM doesn't know about must flow through to the provider."""
+ assert (
+ is_known_text_only_chat_model(
+ provider="CUSTOM",
+ custom_provider="brand_new_proxy",
+ model_name="brand-new-model-x9",
+ )
+ is False
+ )
+
+
+def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
+ """Transient ``litellm.get_model_info`` exception ≠ block. The
+ helper swallows the error and treats it as 'unknown' → False."""
+ import app.services.provider_capabilities as pc
+
+ def _raise(**_kwargs):
+ raise RuntimeError("intentional test failure")
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _raise)
+
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="gpt-4o",
+ )
+ is False
+ )
+
+
+def test_safety_net_fires_only_on_explicit_false(monkeypatch):
+ """Stub LiteLLM to assert the only path that returns True is the
+ explicit ``supports_vision=False`` case. Anything else (True,
+ None, missing key) returns False from the predicate."""
+ import app.services.provider_capabilities as pc
+
+ def _info_explicit_false(**_kwargs):
+ return {"supports_vision": False, "max_input_tokens": 8192}
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="text-only-stub",
+ )
+ is True
+ )
+
+ def _info_true(**_kwargs):
+ return {"supports_vision": True}
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="vision-stub",
+ )
+ is False
+ )
+
+ def _info_missing(**_kwargs):
+ return {"max_input_tokens": 8192}
+
+ monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
+ assert (
+ is_known_text_only_chat_model(
+ provider="OPENAI",
+ model_name="missing-key-stub",
+ )
+ is False
+ )
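
These cases fully specify the predicate: only an explicit `supports_vision=False` from LiteLLM's model map marks a model text-only; a True value, a missing key, or a lookup failure all fall through to the default-allow path. A hedged sketch of such a predicate, assuming the lookup goes through `litellm.get_model_info` as the monkeypatched stubs above suggest (the argument mapping below is an assumption, not the repository's exact code):

import litellm


def is_known_text_only_chat_model(
    *,
    provider: str,
    model_name: str,
    base_model: str | None = None,
    custom_provider: str | None = None,
) -> bool:
    """Return True only when LiteLLM explicitly reports supports_vision=False."""
    try:
        # How provider/custom_provider feed the lookup is assumed here; the
        # tests only pin the behaviour of the returned mapping.
        info = litellm.get_model_info(model=base_model or model_name)
    except Exception:
        return False  # unknown model: default-allow, let the provider decide
    return info.get("supports_vision") is False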
diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py
index 671f57ae4..423b64ddb 100644
--- a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py
+++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py
@@ -105,6 +105,19 @@ async def _denying_billable_call(**kwargs):
yield SimpleNamespace() # pragma: no cover
+@contextlib.asynccontextmanager
+async def _settlement_failing_billable_call(**kwargs):
+ from app.services.billable_calls import BillingSettlementError
+
+ _CALL_LOG.append(kwargs)
+ yield SimpleNamespace()
+ raise BillingSettlementError(
+ usage_type=kwargs.get("usage_type", "?"),
+ user_id=kwargs["user_id"],
+ cause=RuntimeError("finalize failed"),
+ )
+
+
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@@ -176,11 +189,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp
call["quota_reserve_micros_override"]
== app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
)
- assert call["thread_id"] == 99
+ # Background artifact audit rows intentionally omit the TokenUsage.thread_id
+ # FK to avoid coupling Celery audit commits to an active chat transaction.
+ assert "thread_id" not in call
assert call["call_details"] == {
"video_presentation_id": 11,
"title": "Test Presentation",
+ "thread_id": 99,
}
+ assert callable(call["billable_session_factory"])
@pytest.mark.asyncio
@@ -280,6 +297,57 @@ async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch
assert graph_invoked == []
+@pytest.mark.asyncio
+async def test_billing_settlement_failure_marks_video_failed(monkeypatch):
+ from app.db import VideoPresentationStatus
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video(video_id=14)
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return uuid4(), "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _fake_resolver,
+ )
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "billable_call",
+ _settlement_failing_billable_call,
+ )
+
+ async def _fake_graph_invoke(state, config):
+ return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ result = await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=14,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "video_presentation_id": 14,
+ "reason": "billing_settlement_failed",
+ }
+ assert video.status == VideoPresentationStatus.FAILED
+
+
@pytest.mark.asyncio
async def test_resolver_failure_marks_video_failed(monkeypatch):
from app.db import VideoPresentationStatus
diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx
index ffb0e4dc8..3b9d9a526 100644
--- a/surfsense_web/components/assistant-ui/assistant-message.tsx
+++ b/surfsense_web/components/assistant-ui/assistant-message.tsx
@@ -477,9 +477,7 @@ const MessageInfoDropdown: FC = () => {
{counts.total_tokens.toLocaleString()} tokens
- {costMicros && costMicros > 0
- ? ` · ${formatTurnCost(costMicros)}`
- : ""}
+ {costMicros && costMicros > 0 ? ` · ${formatTurnCost(costMicros)}` : ""}
);
diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx
index 1a0f8c5ba..44f3feb7a 100644
--- a/surfsense_web/components/new-chat/model-selector.tsx
+++ b/surfsense_web/components/new-chat/model-selector.tsx
@@ -19,6 +19,7 @@ import {
import type React from "react";
import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
+import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom";
import {
globalImageGenConfigsAtom,
imageGenConfigsAtom,
@@ -461,6 +462,18 @@ export function ModelSelector({
const { data: visionUserConfigs, isLoading: visionUserLoading } =
useAtomValue(visionLLMConfigsAtom);
+ // Pending image attachments on the composer. Used to surface an
+ // amber "No image" hint on chat models the catalog reports as
+ // non-vision (`supports_image_input=false`) when the next message
+ // will carry an image. The hint is purely advisory: selection,
+ // focus, and click handling are unaffected. The backend's safety
+ // net (`is_known_text_only_chat_model`) is the actual block, and
+ // it only fires when LiteLLM *explicitly* marks a model as
+ // text-only — so a model that's secretly capable but hasn't been
+ // annotated will still flow through to the provider.
+ const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom);
+ const hasPendingImages = pendingUserImageUrls.length > 0;
+
const isLoading =
llmUserLoading ||
llmGlobalLoading ||
@@ -984,6 +997,21 @@ export function ModelSelector({
const isSelected = getSelectedId() === config.id;
const isFocused = focusedIndex === index;
const hasCitations = "citations_enabled" in config && !!config.citations_enabled;
+ // Chat-tab only: surface an amber "No image" hint when the
+ // composer carries images and the catalog reports the model as
+ // non-vision. This is purely advisory — selection is *not*
+ // blocked. The backend's narrow safety net
+ // (`is_known_text_only_chat_model`) is the source of truth for
+ // rejecting image turns, and it only fires when LiteLLM
+ // explicitly marks the model as text-only. A model surfaced as
+ // `supports_image_input=false` here may still be capable in
+ // practice (unknown / unmapped LiteLLM entry), so we let the
+ // user pick it and the provider response decide.
+ const isImageIncompatibleChatModel =
+ activeTab === "llm" &&
+ hasPendingImages &&
+ "supports_image_input" in config &&
+ (config as Record<string, unknown>).supports_image_input === false;
return (
handleSelectItem(item)}
onKeyDown={
isMobile
@@ -1005,9 +1038,8 @@ export function ModelSelector({
}
onMouseEnter={() => setFocusedIndex(index)}
className={cn(
- "group flex items-center gap-2.5 px-3 py-2 rounded-xl cursor-pointer",
- "transition-all duration-150 mx-2",
- "hover:bg-accent/40",
+ "group flex items-center gap-2.5 px-3 py-2 rounded-xl",
+ "transition-all duration-150 mx-2 cursor-pointer hover:bg-accent/40",
isSelected && "bg-primary/6 dark:bg-primary/8",
isFocused && "bg-accent/50"
)}
@@ -1053,6 +1085,14 @@ export function ModelSelector({
Free
) : null}
+ {isImageIncompatibleChatModel && (
+
+ No image
+
+ )}
diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx
index 127b79167..156ef9134 100644
--- a/surfsense_web/components/pricing/pricing-section.tsx
+++ b/surfsense_web/components/pricing/pricing-section.tsx
@@ -250,8 +250,8 @@ function PricingFAQ() {
Frequently Asked Questions
- Everything you need to know about SurfSense pages, premium credit, and billing.
- Can't find what you need? Reach out at{" "}
+ Everything you need to know about SurfSense pages, premium credit, and billing. Can't
+ find what you need? Reach out at{" "}
rohan@surfsense.com
diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx
index ced97464e..d4afa698b 100644
--- a/surfsense_web/components/settings/image-model-manager.tsx
+++ b/surfsense_web/components/settings/image-model-manager.tsx
@@ -22,6 +22,7 @@ import {
AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
+import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Card, CardContent } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
@@ -190,8 +191,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
? "model"
: "models"}
{" "}
- available from your administrator.{" "}
- {(() => {
+ available from your administrator. {(() => {
const nonAuto = globalConfigs.filter(
(g) => !("is_auto_mode" in g && g.is_auto_mode)
);
@@ -214,6 +214,75 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
)}
+ {/* Global Image Models — read-only cards with per-model Free/Premium
+ badges. Mirrors the badge palette used by the chat role selector
+ (`llm-role-manager.tsx`) so the meaning is consistent across
+ every model-configuration surface (chat / image / vision). */}
+ {!isLoading &&
+ globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && (
+
+
+ Global Image Models
+
+
+ {globalConfigs
+ .filter((g) => !("is_auto_mode" in g && g.is_auto_mode))
+ .map((cfg) => {
+ const billingTier =
+ ("billing_tier" in cfg &&
+ typeof (cfg as { billing_tier?: string }).billing_tier === "string" &&
+ (cfg as { billing_tier?: string }).billing_tier) ||
+ "free";
+ const isPremium = billingTier === "premium";
+ return (
+
+
+
+
+ {getProviderIcon(cfg.provider, { className: "size-4" })}
+
+
+
+ {cfg.name}
+
+ {isPremium ? (
+
+ Premium
+
+ ) : (
+
+ Free
+
+ )}
+
+
+ {cfg.description && (
+
+ {cfg.description}
+
+ )}
+
+
+ {cfg.model_name}
+
+
+
+
+ );
+ })}
+
+
+ )}
+
{/* Loading Skeleton */}
{isLoading && (
diff --git a/surfsense_web/components/settings/more-pages-content.tsx b/surfsense_web/components/settings/more-pages-content.tsx
index 8de61b0c7..5635c3314 100644
--- a/surfsense_web/components/settings/more-pages-content.tsx
+++ b/surfsense_web/components/settings/more-pages-content.tsx
@@ -70,9 +70,7 @@ export function MorePagesContent() {
Get Free Pages
-
- Earn bonus pages by completing tasks
-
+
Earn bonus pages by completing tasks
diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx
index 886d71008..34aa531fd 100644
--- a/surfsense_web/components/settings/vision-model-manager.tsx
+++ b/surfsense_web/components/settings/vision-model-manager.tsx
@@ -22,6 +22,7 @@ import {
AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
+import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Card, CardContent } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
@@ -191,8 +192,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
? "model"
: "models"}
{" "}
- available from your administrator.{" "}
- {(() => {
+ available from your administrator. {(() => {
const nonAuto = globalConfigs.filter(
(g) => !("is_auto_mode" in g && g.is_auto_mode)
);
@@ -215,6 +215,75 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
)}
+ {/* Global Vision Models — read-only cards with per-model Free/Premium
+ badges. Mirrors the badge palette used by the chat role selector
+ (`llm-role-manager.tsx`) so the meaning is consistent across
+ every model-configuration surface (chat / image / vision). */}
+ {!isLoading &&
+ globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && (
+
+
+ Global Vision Models
+
+
+ {globalConfigs
+ .filter((g) => !("is_auto_mode" in g && g.is_auto_mode))
+ .map((cfg) => {
+ const billingTier =
+ ("billing_tier" in cfg &&
+ typeof (cfg as { billing_tier?: string }).billing_tier === "string" &&
+ (cfg as { billing_tier?: string }).billing_tier) ||
+ "free";
+ const isPremium = billingTier === "premium";
+ return (
+
+
+
+
+ {getProviderIcon(cfg.provider, { className: "size-4" })}
+
+
+
+ {cfg.name}
+
+ {isPremium ? (
+
+ Premium
+
+ ) : (
+
+ Free
+
+ )}
+
+
+ {cfg.description && (
+
+ {cfg.description}
+
+ )}
+
+
+ {cfg.model_name}
+
+
+
+
+ );
+ })}
+
+
+ )}
+
{isLoading && (
diff --git a/surfsense_web/components/tool-ui/generate-podcast.tsx b/surfsense_web/components/tool-ui/generate-podcast.tsx
index 02f53efad..e8fff2873 100644
--- a/surfsense_web/components/tool-ui/generate-podcast.tsx
+++ b/surfsense_web/components/tool-ui/generate-podcast.tsx
@@ -416,9 +416,19 @@ export const GeneratePodcastToolUI = ({
return
;
}
- // Already generating - show simple warning, don't create another poller
- // The FIRST tool call will display the podcast when ready
- // (new: "generating", legacy: "already_generating")
+ // Pending/generating rows have a stable podcast_id, so the card can poll
+ // independently while the chat stream finishes.
+ if (
+ (result.status === "pending" ||
+ result.status === "generating" ||
+ result.status === "processing") &&
+ result.podcast_id
+ ) {
+ return
;
+ }
+
+ // Legacy duplicate/no-ID result - show a simple warning, don't create
+ // another poller. The first tool call will display the podcast when ready.
if (result.status === "generating" || result.status === "already_generating") {
return (
@@ -432,11 +442,6 @@ export const GeneratePodcastToolUI = ({
);
}
- // Pending - poll for completion (new: "pending" with podcast_id)
- if (result.status === "pending" && result.podcast_id) {
- return
;
- }
-
// Ready with podcast_id (new: "ready", legacy: "success")
if ((result.status === "ready" || result.status === "success") && result.podcast_id) {
return
;
diff --git a/surfsense_web/contexts/login-gate.tsx b/surfsense_web/contexts/login-gate.tsx
index 790e5c00e..f72cb3a42 100644
--- a/surfsense_web/contexts/login-gate.tsx
+++ b/surfsense_web/contexts/login-gate.tsx
@@ -44,8 +44,8 @@ export function LoginGateProvider({ children }: { children: ReactNode }) {
Create a free account to {feature}
- Get $5 of premium credit, save chat history, upload documents, use all AI tools,
- and connect 30+ integrations.
+ Get $5 of premium credit, save chat history, upload documents, use all AI tools, and
+ connect 30+ integrations.
diff --git a/surfsense_web/contracts/types/new-llm-config.types.ts b/surfsense_web/contracts/types/new-llm-config.types.ts
index 2d6b70eda..b52b98ae4 100644
--- a/surfsense_web/contracts/types/new-llm-config.types.ts
+++ b/surfsense_web/contracts/types/new-llm-config.types.ts
@@ -65,6 +65,13 @@ export const newLLMConfig = z.object({
created_at: z.string(),
search_space_id: z.number(),
user_id: z.string(),
+
+ // Capability flag — derived server-side at the route boundary from
+ // LiteLLM's authoritative model map. There is no DB column. Default
+ // `true` is the conservative-allow stance for unknown / unmapped
+ // BYOK rows; the streaming-task safety net is the only place a
+ // `false` actually blocks a request.
+ supports_image_input: z.boolean().default(true),
});
/**
@@ -74,11 +81,16 @@ export const newLLMConfigPublic = newLLMConfig.omit({ api_key: true });
/**
* Create NewLLMConfig
+ *
+ * `supports_image_input` is omitted because it is derived server-side
+ * from LiteLLM's model map at read time — there is no DB column to
+ * persist a client-supplied value into.
*/
export const createNewLLMConfigRequest = newLLMConfig.omit({
id: true,
created_at: true,
user_id: true,
+ supports_image_input: true,
});
export const createNewLLMConfigResponse = newLLMConfig;
@@ -114,6 +126,8 @@ export const updateNewLLMConfigRequest = z.object({
created_at: true,
search_space_id: true,
user_id: true,
+ // Derived server-side; not part of the writable surface.
+ supports_image_input: true,
})
.partial(),
});
@@ -172,6 +186,16 @@ export const globalNewLLMConfig = z.object({
seo_title: z.string().nullable().optional(),
seo_description: z.string().nullable().optional(),
quota_reserve_tokens: z.number().nullable().optional(),
+ // Capability flag — true when the model can accept image inputs.
+ // Resolved server-side (OpenRouter dynamic configs use the OR
+ // `architecture.input_modalities` field; YAML / BYOK use LiteLLM's
+ // authoritative `supports_vision` map). The chat selector renders
+ // an amber "No image" hint when this is false and there are
+ // pending image attachments, but does not block selection — the
+ // backend safety net only rejects when LiteLLM *explicitly* marks
+ // the model as text-only, so unknown / new models still flow
+ // through. Default `true` matches that conservative-allow stance.
+ supports_image_input: z.boolean().default(true),
});
export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig);
@@ -259,6 +283,9 @@ export const globalImageGenConfig = z.object({
is_global: z.literal(true),
is_auto_mode: z.boolean().optional().default(false),
billing_tier: z.string().default("free"),
+ // Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's
+ // Free/Premium badge logic lights up automatically for image-gen too.
+ is_premium: z.boolean().default(false),
quota_reserve_micros: z.number().nullable().optional(),
});
@@ -341,6 +368,9 @@ export const globalVisionLLMConfig = z.object({
is_global: z.literal(true),
is_auto_mode: z.boolean().optional().default(false),
billing_tier: z.string().default("free"),
+ // Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's
+ // Free/Premium badge logic lights up automatically for vision too.
+ is_premium: z.boolean().default(false),
quota_reserve_tokens: z.number().nullable().optional(),
input_cost_per_token: z.number().nullable().optional(),
output_cost_per_token: z.number().nullable().optional(),
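
The `supports_image_input` comments in this contracts file describe two server-side resolution paths: OpenRouter dynamic configs read the `architecture.input_modalities` field, while YAML / BYOK rows consult LiteLLM's `supports_vision` map and default to allow when the model is unknown. A rough Python sketch of that resolution, assuming a plain dict for the catalog entry; the function name and field access are illustrative, not the actual route-boundary code:

import litellm


def resolve_supports_image_input(raw_entry: dict) -> bool:
    # OpenRouter dynamic configs: the catalog row itself carries modality info.
    modalities = (raw_entry.get("architecture") or {}).get("input_modalities")
    if isinstance(modalities, list):
        return "image" in modalities
    # YAML / BYOK rows: consult LiteLLM's model map; unknown models default to allow.
    try:
        info = litellm.get_model_info(model=raw_entry["model_name"])
    except Exception:
        return True
    return info.get("supports_vision") is not False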
diff --git a/surfsense_web/next.config.ts b/surfsense_web/next.config.ts
index 5414d548d..6cfcb5187 100644
--- a/surfsense_web/next.config.ts
+++ b/surfsense_web/next.config.ts
@@ -18,6 +18,12 @@ const nextConfig: NextConfig = {
},
images: {
remotePatterns: [
+ {
+ protocol: "http",
+ hostname: "localhost",
+ port: "8000",
+ pathname: "/api/v1/image-generations/**",
+ },
{
protocol: "https",
hostname: "**",
From cea8618aed74840fe7d48a669dfd4fc07e0039cc Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Sat, 2 May 2026 21:16:03 -0700
Subject: [PATCH 04/12] fix: resolve Composio connector issues
---
.../new_chat/tools/gmail/composio_helpers.py | 41 ++
.../new_chat/tools/gmail/create_draft.py | 61 ++-
.../agents/new_chat/tools/gmail/read_email.py | 48 +++
.../new_chat/tools/gmail/search_emails.py | 72 +++-
.../agents/new_chat/tools/gmail/send_email.py | 61 ++-
.../new_chat/tools/gmail/trash_email.py | 50 ++-
.../new_chat/tools/gmail/update_draft.py | 116 ++++--
.../tools/google_calendar/create_event.py | 68 +++-
.../tools/google_calendar/delete_event.py | 50 ++-
.../tools/google_calendar/search_events.py | 99 +++--
.../tools/google_calendar/update_event.py | 84 +++--
.../tools/google_drive/create_file.py | 60 ++-
.../new_chat/tools/google_drive/trash_file.py | 32 +-
.../app/agents/new_chat/tools/hitl.py | 1 +
.../app/routes/composio_routes.py | 41 +-
.../app/services/composio_service.py | 105 +++++-
.../services/gmail/tool_metadata_service.py | 122 +++++-
.../google_calendar/kb_sync_service.py | 65 +++-
.../google_calendar/tool_metadata_service.py | 282 ++++++++++++--
.../google_drive/tool_metadata_service.py | 96 ++++-
.../app/tasks/chat/stream_new_chat.py | 46 ++-
.../google_calendar_indexer.py | 59 ++-
.../google_drive_indexer.py | 356 +++++++++++++-----
.../google_gmail_indexer.py | 146 ++++++-
.../tasks/chat/test_tool_input_streaming.py | 56 ++-
25 files changed, 1756 insertions(+), 461 deletions(-)
create mode 100644 surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py b/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py
new file mode 100644
index 000000000..0ca1191a4
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py
@@ -0,0 +1,41 @@
+from typing import Any
+
+from app.db import SearchSourceConnector
+from app.services.composio_service import ComposioService
+
+
+def split_recipients(value: str | None) -> list[str]:
+ if not value:
+ return []
+ return [recipient.strip() for recipient in value.split(",") if recipient.strip()]
+
+
+def unwrap_composio_data(data: Any) -> Any:
+ if isinstance(data, dict):
+ inner = data.get("data", data)
+ if isinstance(inner, dict):
+ return inner.get("response_data", inner)
+ return inner
+ return data
+
+
+async def execute_composio_gmail_tool(
+ connector: SearchSourceConnector,
+ user_id: str,
+ tool_name: str,
+ params: dict[str, Any],
+) -> tuple[Any, str | None]:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return None, "Composio connected account ID not found for this Gmail connector."
+
+ result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name=tool_name,
+ params=params,
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown Composio Gmail error")
+
+ return unwrap_composio_data(result.get("data")), None
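
Composio responses can arrive double-wrapped (a `data` envelope that itself contains a `response_data` envelope). `unwrap_composio_data` peels whichever layers are present and passes anything else through untouched, as this small illustration of the accepted shapes shows:

from app.agents.new_chat.tools.gmail.composio_helpers import unwrap_composio_data

# Fully wrapped: data -> response_data -> payload
assert unwrap_composio_data({"data": {"response_data": {"id": "m1"}}}) == {"id": "m1"}
# Single envelope: data -> payload
assert unwrap_composio_data({"data": {"id": "m2"}}) == {"id": "m2"}
# Already unwrapped, or not a dict at all: returned as-is
assert unwrap_composio_data({"id": "m3"}) == {"id": "m3"}
assert unwrap_composio_data(["raw"]) == ["raw"]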
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py
index 0bd044695..7e9ddf7d3 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py
@@ -157,16 +157,13 @@ def create_create_gmail_draft_tool(
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
)
- if (
+ is_composio_gmail = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
+ )
+ if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
+ if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
@@ -208,10 +205,6 @@ def create_create_gmail_draft_tool(
expiry=datetime.fromisoformat(exp) if exp else None,
)
- from googleapiclient.discovery import build
-
- gmail_service = build("gmail", "v1", credentials=creds)
-
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
@@ -222,15 +215,43 @@ def create_create_gmail_draft_tool(
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
- created = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .drafts()
- .create(userId="me", body={"message": {"raw": raw}})
- .execute()
- ),
- )
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ split_recipients,
+ )
+
+ created, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_CREATE_EMAIL_DRAFT",
+ {
+ "user_id": "me",
+ "recipient_email": final_to,
+ "subject": final_subject,
+ "body": final_body,
+ "cc": split_recipients(final_cc),
+ "bcc": split_recipients(final_bcc),
+ "is_html": False,
+ },
+ )
+ if error:
+ raise RuntimeError(error)
+ if not isinstance(created, dict):
+ created = {}
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ created = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .drafts()
+ .create(userId="me", body={"message": {"raw": raw}})
+ .execute()
+ ),
+ )
except Exception as api_err:
from googleapiclient.errors import HttpError
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
index deec1627c..1964181e4 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
@@ -50,6 +50,54 @@ def create_read_gmail_email_tool(
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
+ if (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ ):
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found.",
+ }
+
+ from app.agents.new_chat.tools.gmail.search_emails import (
+ _format_gmail_summary,
+ )
+ from app.services.composio_service import ComposioService
+
+ service = ComposioService()
+ detail, error = await service.get_gmail_message_detail(
+ connected_account_id=cca_id,
+ entity_id=f"surfsense_{user_id}",
+ message_id=message_id,
+ )
+ if error:
+ return {"status": "error", "message": error}
+ if not detail:
+ return {
+ "status": "not_found",
+ "message": f"Email with ID '{message_id}' not found.",
+ }
+
+ summary = _format_gmail_summary(detail)
+ content = (
+ f"# {summary['subject']}\n\n"
+ f"**From:** {summary['from']}\n"
+ f"**To:** {summary['to']}\n"
+ f"**Date:** {summary['date']}\n\n"
+ f"## Message Content\n\n"
+ f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
+ f"## Message Details\n\n"
+ f"- **Message ID:** {summary['message_id']}\n"
+ f"- **Thread ID:** {summary['thread_id']}\n"
+ )
+ return {
+ "status": "success",
+ "message_id": summary["message_id"] or message_id,
+ "content": content,
+ }
+
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
creds = _build_credentials(connector)
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
index 2e363609e..59886159a 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
@@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector):
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
- from app.utils.google_credentials import build_composio_credentials
-
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
- raise ValueError("Composio connected account ID not found.")
- return build_composio_credentials(cca_id)
+ raise ValueError("Composio connectors must use Composio tool execution.")
from google.oauth2.credentials import Credentials
@@ -67,6 +62,63 @@ def _build_credentials(connector: SearchSourceConnector):
)
+def _gmail_headers(message: dict[str, Any]) -> dict[str, str]:
+ headers = message.get("payload", {}).get("headers", [])
+ return {
+ header.get("name", "").lower(): header.get("value", "")
+ for header in headers
+ if isinstance(header, dict)
+ }
+
+
+def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]:
+ headers = _gmail_headers(message)
+ return {
+ "message_id": message.get("id") or message.get("messageId"),
+ "thread_id": message.get("threadId"),
+ "subject": message.get("subject") or headers.get("subject", "No Subject"),
+ "from": message.get("sender") or headers.get("from", "Unknown"),
+ "to": message.get("to") or headers.get("to", ""),
+ "date": message.get("messageTimestamp") or headers.get("date", ""),
+ "snippet": message.get("snippet") or message.get("messageText", "")[:300],
+ "labels": message.get("labelIds", []),
+ }
+
+
+async def _search_composio_gmail(
+ connector: SearchSourceConnector,
+ user_id: str,
+ query: str,
+ max_results: int,
+) -> dict[str, Any]:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found.",
+ }
+
+ from app.services.composio_service import ComposioService
+
+ service = ComposioService()
+ messages, _next_token, _estimate, error = await service.get_gmail_messages(
+ connected_account_id=cca_id,
+ entity_id=f"surfsense_{user_id}",
+ query=query,
+ max_results=max_results,
+ )
+ if error:
+ return {"status": "error", "message": error}
+
+ emails = [_format_gmail_summary(message) for message in messages]
+ return {
+ "status": "success",
+ "emails": emails,
+ "total": len(emails),
+ "message": "No emails found." if not emails else None,
+ }
+
+
def create_search_gmail_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
@@ -110,6 +162,14 @@ def create_search_gmail_tool(
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
+ if (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ ):
+ return await _search_composio_gmail(
+ connector, str(user_id), query, max_results
+ )
+
creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
index c3f0999f4..79ff2d9c7 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
@@ -158,16 +158,13 @@ def create_send_gmail_email_tool(
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
)
- if (
+ is_composio_gmail = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
+ )
+ if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
+ if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
@@ -209,10 +206,6 @@ def create_send_gmail_email_tool(
expiry=datetime.fromisoformat(exp) if exp else None,
)
- from googleapiclient.discovery import build
-
- gmail_service = build("gmail", "v1", credentials=creds)
-
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
@@ -223,15 +216,43 @@ def create_send_gmail_email_tool(
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
- sent = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .messages()
- .send(userId="me", body={"raw": raw})
- .execute()
- ),
- )
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ split_recipients,
+ )
+
+ sent, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_SEND_EMAIL",
+ {
+ "user_id": "me",
+ "recipient_email": final_to,
+ "subject": final_subject,
+ "body": final_body,
+ "cc": split_recipients(final_cc),
+ "bcc": split_recipients(final_bcc),
+ "is_html": False,
+ },
+ )
+ if error:
+ raise RuntimeError(error)
+ if not isinstance(sent, dict):
+ sent = {}
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ sent = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .messages()
+ .send(userId="me", body={"raw": raw})
+ .execute()
+ ),
+ )
except Exception as api_err:
from googleapiclient.errors import HttpError
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
index 1f1f6227a..4e710dc72 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
@@ -158,16 +158,13 @@ def create_trash_gmail_email_tool(
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
)
- if (
+ is_composio_gmail = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
+ )
+ if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
+ if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
@@ -209,20 +206,33 @@ def create_trash_gmail_email_tool(
expiry=datetime.fromisoformat(exp) if exp else None,
)
- from googleapiclient.discovery import build
-
- gmail_service = build("gmail", "v1", credentials=creds)
-
try:
- await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .messages()
- .trash(userId="me", id=final_message_id)
- .execute()
- ),
- )
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ )
+
+ _trashed, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_MOVE_TO_TRASH",
+ {"user_id": "me", "message_id": final_message_id},
+ )
+ if error:
+ raise RuntimeError(error)
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .messages()
+ .trash(userId="me", id=final_message_id)
+ .execute()
+ ),
+ )
except Exception as api_err:
from googleapiclient.errors import HttpError
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py
index 91178cd21..50956f03a 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py
@@ -188,16 +188,13 @@ def create_update_gmail_draft_tool(
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
)
- if (
+ is_composio_gmail = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
+ )
+ if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
+ if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
@@ -239,18 +236,22 @@ def create_update_gmail_draft_tool(
expiry=datetime.fromisoformat(exp) if exp else None,
)
- from googleapiclient.discovery import build
-
- gmail_service = build("gmail", "v1", credentials=creds)
-
# Resolve draft_id if not already available
if not final_draft_id:
logger.info(
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
)
- final_draft_id = await _find_draft_id_by_message(
- gmail_service, message_id
- )
+ if is_composio_gmail:
+ final_draft_id = await _find_composio_draft_id_by_message(
+ connector, user_id, message_id
+ )
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ final_draft_id = await _find_draft_id_by_message(
+ gmail_service, message_id
+ )
if not final_draft_id:
return {
@@ -272,19 +273,48 @@ def create_update_gmail_draft_tool(
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
- updated = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .drafts()
- .update(
- userId="me",
- id=final_draft_id,
- body={"message": {"raw": raw}},
- )
- .execute()
- ),
- )
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ split_recipients,
+ )
+
+ updated, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_UPDATE_DRAFT",
+ {
+ "user_id": "me",
+ "draft_id": final_draft_id,
+ "recipient_email": final_to,
+ "subject": final_subject,
+ "body": final_body,
+ "cc": split_recipients(final_cc),
+ "bcc": split_recipients(final_bcc),
+ "is_html": False,
+ },
+ )
+ if error:
+ raise RuntimeError(error)
+ if not isinstance(updated, dict):
+ updated = {}
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ updated = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .drafts()
+ .update(
+ userId="me",
+ id=final_draft_id,
+ body={"message": {"raw": raw}},
+ )
+ .execute()
+ ),
+ )
except Exception as api_err:
from googleapiclient.errors import HttpError
@@ -408,3 +438,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str
except Exception as e:
logger.warning(f"Failed to look up draft by message_id: {e}")
return None
+
+
+async def _find_composio_draft_id_by_message(
+ connector: Any, user_id: str, message_id: str
+) -> str | None:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ )
+
+ page_token = ""
+ while True:
+ params: dict[str, Any] = {
+ "user_id": "me",
+ "max_results": 100,
+ "verbose": False,
+ }
+ if page_token:
+ params["page_token"] = page_token
+
+ data, error = await execute_composio_gmail_tool(
+ connector, user_id, "GMAIL_LIST_DRAFTS", params
+ )
+ if error or not isinstance(data, dict):
+ return None
+
+ for draft in data.get("drafts", []):
+ if draft.get("message", {}).get("id") == message_id:
+ return draft.get("id")
+
+ page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
+ if not page_token:
+ return None
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
index 37bcf083e..0a4720f6f 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
@@ -168,16 +168,13 @@ def create_create_calendar_event_tool(
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
)
- if (
+ is_composio_calendar = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
+ )
+ if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
+ if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
@@ -211,10 +208,6 @@ def create_create_calendar_event_tool(
expiry=datetime.fromisoformat(exp) if exp else None,
)
- service = await asyncio.get_event_loop().run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
-
tz = context.get("timezone", "UTC")
event_body: dict[str, Any] = {
"summary": final_summary,
@@ -231,14 +224,51 @@ def create_create_calendar_event_tool(
]
try:
- created = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- service.events()
- .insert(calendarId="primary", body=event_body)
- .execute()
- ),
- )
+ if is_composio_calendar:
+ from app.services.composio_service import ComposioService
+
+ composio_params = {
+ "calendar_id": "primary",
+ "summary": final_summary,
+ "start_datetime": final_start_datetime,
+ "end_datetime": final_end_datetime,
+ "timezone": tz,
+ "attendees": final_attendees or [],
+ }
+ if final_description:
+ composio_params["description"] = final_description
+ if final_location:
+ composio_params["location"] = final_location
+
+ composio_result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLECALENDAR_CREATE_EVENT",
+ params=composio_params,
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not composio_result.get("success"):
+ raise RuntimeError(
+ composio_result.get(
+ "error", "Unknown Composio Calendar error"
+ )
+ )
+ created = composio_result.get("data", {})
+ if isinstance(created, dict):
+ created = created.get("data", created)
+ if isinstance(created, dict):
+ created = created.get("response_data", created)
+ else:
+ service = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ created = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .insert(calendarId="primary", body=event_body)
+ .execute()
+ ),
+ )
except Exception as api_err:
from googleapiclient.errors import HttpError
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
index 4d9d69b4b..53596ac0f 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
@@ -159,16 +159,13 @@ def create_delete_calendar_event_tool(
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
)
- if (
+ is_composio_calendar = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
+ )
+ if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
+ if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
@@ -202,19 +199,34 @@ def create_delete_calendar_event_tool(
expiry=datetime.fromisoformat(exp) if exp else None,
)
- service = await asyncio.get_event_loop().run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
-
try:
- await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- service.events()
- .delete(calendarId="primary", eventId=final_event_id)
- .execute()
- ),
- )
+ if is_composio_calendar:
+ from app.services.composio_service import ComposioService
+
+ composio_result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLECALENDAR_DELETE_EVENT",
+ params={"calendar_id": "primary", "event_id": final_event_id},
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not composio_result.get("success"):
+ raise RuntimeError(
+ composio_result.get(
+ "error", "Unknown Composio Calendar error"
+ )
+ )
+ else:
+ service = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .delete(calendarId="primary", eventId=final_event_id)
+ .execute()
+ ),
+ )
except Exception as api_err:
from googleapiclient.errors import HttpError
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
index dc6adb822..b5194d15f 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
@@ -16,6 +16,35 @@ _CALENDAR_TYPES = [
]
+def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
+ if "T" in value:
+ return value
+ time = "23:59:59" if is_end else "00:00:00"
+ return f"{value}T{time}Z"
+
+
+def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ events = []
+ for ev in events_raw:
+ start = ev.get("start", {})
+ end = ev.get("end", {})
+ attendees_raw = ev.get("attendees", [])
+ events.append(
+ {
+ "event_id": ev.get("id"),
+ "summary": ev.get("summary", "No Title"),
+ "start": start.get("dateTime") or start.get("date", ""),
+ "end": end.get("dateTime") or end.get("date", ""),
+ "location": ev.get("location", ""),
+ "description": ev.get("description", ""),
+ "html_link": ev.get("htmlLink", ""),
+ "attendees": [a.get("email", "") for a in attendees_raw[:10]],
+ "status": ev.get("status", ""),
+ }
+ )
+ return events
+
+
def create_search_calendar_events_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
@@ -61,22 +90,47 @@ def create_search_calendar_events_tool(
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
}
- creds = _build_credentials(connector)
+ if (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+ ):
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this connector.",
+ }
- from app.connectors.google_calendar_connector import GoogleCalendarConnector
+ from app.services.composio_service import ComposioService
- cal = GoogleCalendarConnector(
- credentials=creds,
- session=db_session,
- user_id=user_id,
- connector_id=connector.id,
- )
+ events_raw, error = await ComposioService().get_calendar_events(
+ connected_account_id=cca_id,
+ entity_id=f"surfsense_{user_id}",
+ time_min=_to_calendar_boundary(start_date, is_end=False),
+ time_max=_to_calendar_boundary(end_date, is_end=True),
+ max_results=max_results,
+ )
+ if not events_raw and not error:
+ error = "No events found in the specified date range."
+ else:
+ creds = _build_credentials(connector)
- events_raw, error = await cal.get_all_primary_calendar_events(
- start_date=start_date,
- end_date=end_date,
- max_results=max_results,
- )
+ from app.connectors.google_calendar_connector import (
+ GoogleCalendarConnector,
+ )
+
+ cal = GoogleCalendarConnector(
+ credentials=creds,
+ session=db_session,
+ user_id=user_id,
+ connector_id=connector.id,
+ )
+
+ events_raw, error = await cal.get_all_primary_calendar_events(
+ start_date=start_date,
+ end_date=end_date,
+ max_results=max_results,
+ )
if error:
if (
@@ -97,24 +151,7 @@ def create_search_calendar_events_tool(
}
return {"status": "error", "message": error}
- events = []
- for ev in events_raw:
- start = ev.get("start", {})
- end = ev.get("end", {})
- attendees_raw = ev.get("attendees", [])
- events.append(
- {
- "event_id": ev.get("id"),
- "summary": ev.get("summary", "No Title"),
- "start": start.get("dateTime") or start.get("date", ""),
- "end": end.get("dateTime") or end.get("date", ""),
- "location": ev.get("location", ""),
- "description": ev.get("description", ""),
- "html_link": ev.get("htmlLink", ""),
- "attendees": [a.get("email", "") for a in attendees_raw[:10]],
- "status": ev.get("status", ""),
- }
- )
+ events = _format_calendar_events(events_raw)
return {"status": "success", "events": events, "total": len(events)}
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
index 259f52bba..1dba36c20 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
@@ -192,16 +192,13 @@ def create_update_calendar_event_tool(
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
)
- if (
+ is_composio_calendar = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
+ )
+ if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- creds = build_composio_credentials(cca_id)
- else:
+ if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
@@ -235,10 +232,6 @@ def create_update_calendar_event_tool(
expiry=datetime.fromisoformat(exp) if exp else None,
)
- service = await asyncio.get_event_loop().run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
-
update_body: dict[str, Any] = {}
if final_new_summary is not None:
update_body["summary"] = final_new_summary
@@ -264,18 +257,65 @@ def create_update_calendar_event_tool(
}
try:
- updated = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- service.events()
- .patch(
- calendarId="primary",
- eventId=final_event_id,
- body=update_body,
+ if is_composio_calendar:
+ from app.services.composio_service import ComposioService
+
+ composio_params: dict[str, Any] = {
+ "calendar_id": "primary",
+ "event_id": final_event_id,
+ }
+ if final_new_summary is not None:
+ composio_params["summary"] = final_new_summary
+ if final_new_start_datetime is not None:
+ composio_params["start_time"] = final_new_start_datetime
+ if final_new_end_datetime is not None:
+ composio_params["end_time"] = final_new_end_datetime
+ if final_new_description is not None:
+ composio_params["description"] = final_new_description
+ if final_new_location is not None:
+ composio_params["location"] = final_new_location
+ if final_new_attendees is not None:
+ composio_params["attendees"] = [
+ e.strip() for e in final_new_attendees if e.strip()
+ ]
+ if not _is_date_only(
+ final_new_start_datetime or final_new_end_datetime or ""
+ ):
+ composio_params["timezone"] = context.get("timezone", "UTC")
+
+ composio_result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLECALENDAR_PATCH_EVENT",
+ params=composio_params,
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not composio_result.get("success"):
+ raise RuntimeError(
+ composio_result.get(
+ "error", "Unknown Composio Calendar error"
+ )
)
- .execute()
- ),
- )
+ updated = composio_result.get("data", {})
+ if isinstance(updated, dict):
+ updated = updated.get("data", updated)
+ if isinstance(updated, dict):
+ updated = updated.get("response_data", updated)
+ else:
+ service = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ updated = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .patch(
+ calendarId="primary",
+ eventId=final_event_id,
+ body=update_body,
+ )
+ .execute()
+ ),
+ )
except Exception as api_err:
from googleapiclient.errors import HttpError
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py
index f36db8f3f..2becec100 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py
@@ -179,29 +179,59 @@ def create_create_google_drive_file_tool(
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
)
- pre_built_creds = None
- if (
+ is_composio_drive = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
+ )
+ if is_composio_drive:
cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- pre_built_creds = build_composio_credentials(cca_id)
-
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Drive connector.",
+ }
client = GoogleDriveClient(
session=db_session,
connector_id=actual_connector_id,
- credentials=pre_built_creds,
)
try:
- created = await client.create_file(
- name=final_name,
- mime_type=mime_type,
- parent_folder_id=final_parent_folder_id,
- content=final_content,
- )
+ if is_composio_drive:
+ from app.services.composio_service import ComposioService
+
+ params: dict[str, Any] = {
+ "name": final_name,
+ "mimeType": mime_type,
+ "fields": "id,name,webViewLink,mimeType",
+ }
+ if final_parent_folder_id:
+ params["parents"] = [final_parent_folder_id]
+ if final_content:
+ params["description"] = final_content[:4096]
+
+ result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLEDRIVE_CREATE_FILE",
+ params=params,
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not result.get("success"):
+ raise RuntimeError(
+ result.get("error", "Unknown Composio Drive error")
+ )
+ created = result.get("data", {})
+ if isinstance(created, dict):
+ created = created.get("data", created)
+ if isinstance(created, dict):
+ created = created.get("response_data", created)
+ if not isinstance(created, dict):
+ created = {}
+ else:
+ created = await client.create_file(
+ name=final_name,
+ mime_type=mime_type,
+ parent_folder_id=final_parent_folder_id,
+ content=final_content,
+ )
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py
index 832afff0d..3c404527e 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py
@@ -158,24 +158,38 @@ def create_delete_google_drive_file_tool(
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
)
- pre_built_creds = None
- if (
+ is_composio_drive = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
- ):
- from app.utils.google_credentials import build_composio_credentials
-
+ )
+ if is_composio_drive:
cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- pre_built_creds = build_composio_credentials(cca_id)
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Drive connector.",
+ }
client = GoogleDriveClient(
session=db_session,
connector_id=connector.id,
- credentials=pre_built_creds,
)
try:
- await client.trash_file(file_id=final_file_id)
+ if is_composio_drive:
+ from app.services.composio_service import ComposioService
+
+ result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLEDRIVE_TRASH_FILE",
+ params={"file_id": final_file_id},
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not result.get("success"):
+ raise RuntimeError(
+ result.get("error", "Unknown Composio Drive error")
+ )
+ else:
+ await client.trash_file(file_id=final_file_id)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py
index 92248c2c9..5b64929de 100644
--- a/surfsense_backend/app/agents/new_chat/tools/hitl.py
+++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py
@@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
{
"create_gmail_draft",
"update_gmail_draft",
+ "create_calendar_event",
"create_notion_page",
"create_confluence_page",
"create_google_drive_file",
diff --git a/surfsense_backend/app/routes/composio_routes.py b/surfsense_backend/app/routes/composio_routes.py
index 4bf360365..7bc2addf8 100644
--- a/surfsense_backend/app/routes/composio_routes.py
+++ b/surfsense_backend/app/routes/composio_routes.py
@@ -649,13 +649,9 @@ async def list_composio_drive_folders(
"""
List folders AND files in user's Google Drive via Composio.
- Uses the same GoogleDriveClient / list_folder_contents path as the native
- connector, with Composio-sourced credentials. This means auth errors
- propagate identically (Google returns 401 → exception → auth_expired flag).
+ Uses Composio's Google Drive tool execution path so managed OAuth tokens
+ do not need to be exposed through connected account state.
"""
- from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
- from app.utils.google_credentials import build_composio_credentials
-
if not ComposioService.is_enabled():
raise HTTPException(
status_code=503,
@@ -689,10 +685,37 @@ async def list_composio_drive_folders(
detail="Composio connected account not found. Please reconnect the connector.",
)
- credentials = build_composio_credentials(composio_connected_account_id)
- drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
+ service = ComposioService()
+ entity_id = f"surfsense_{user.id}"
+ items = []
+ page_token = None
+ error = None
- items, error = await list_folder_contents(drive_client, parent_id=parent_id)
+ while True:
+ page_items, next_token, page_error = await service.get_drive_files(
+ connected_account_id=composio_connected_account_id,
+ entity_id=entity_id,
+ folder_id=parent_id,
+ page_token=page_token,
+ page_size=100,
+ )
+ if page_error:
+ error = page_error
+ break
+
+ items.extend(page_items)
+ if not next_token:
+ break
+ page_token = next_token
+
+ for item in items:
+ item["isFolder"] = (
+ item.get("mimeType") == "application/vnd.google-apps.folder"
+ )
+
+ items.sort(
+ key=lambda item: (not item["isFolder"], item.get("name", "").lower())
+ )
if error:
error_lower = error.lower()
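
The folder listing above swaps a single GoogleDriveClient call for a page-by-page accumulation loop over Composio, then marks folders and sorts them ahead of files. A minimal sketch of the same pattern, with fetch_page standing in for ComposioService.get_drive_files (hypothetical stub, invented data):

import asyncio
from typing import Any

FOLDER = "application/vnd.google-apps.folder"

async def fetch_page(
    page_token: str | None,
) -> tuple[list[dict[str, Any]], str | None, str | None]:
    # Stand-in for ComposioService.get_drive_files: (items, next_token, error).
    pages = {
        None: ([{"id": "1", "name": "b.txt", "mimeType": "text/plain"}], "tok-2", None),
        "tok-2": ([{"id": "2", "name": "A", "mimeType": FOLDER}], None, None),
    }
    return pages[page_token]

async def list_all() -> list[dict[str, Any]]:
    items: list[dict[str, Any]] = []
    page_token: str | None = None
    while True:  # accumulate until the service stops returning a next-page token
        page_items, next_token, error = await fetch_page(page_token)
        if error:
            break
        items.extend(page_items)
        if not next_token:
            break
        page_token = next_token
    # Folders first, then case-insensitive name order (same key as the route).
    for item in items:
        item["isFolder"] = item.get("mimeType") == FOLDER
    items.sort(key=lambda item: (not item["isFolder"], item.get("name", "").lower()))
    return items

print(asyncio.run(list_all()))  # folder "A" sorts before file "b.txt"
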
diff --git a/surfsense_backend/app/services/composio_service.py b/surfsense_backend/app/services/composio_service.py
index a8abe4aa8..edfab1d15 100644
--- a/surfsense_backend/app/services/composio_service.py
+++ b/surfsense_backend/app/services/composio_service.py
@@ -408,12 +408,37 @@ class ComposioService:
files = []
next_token = None
if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ response_data = (
+ inner_data.get("response_data", {})
+ if isinstance(inner_data, dict)
+ else {}
+ )
# Try direct access first, then nested
- files = data.get("files", []) or data.get("data", {}).get("files", [])
+ files = (
+ data.get("files", [])
+ or (
+ inner_data.get("files", [])
+ if isinstance(inner_data, dict)
+ else []
+ )
+ or response_data.get("files", [])
+ )
next_token = (
data.get("nextPageToken")
or data.get("next_page_token")
- or data.get("data", {}).get("nextPageToken")
+ or (
+ inner_data.get("nextPageToken")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or (
+ inner_data.get("next_page_token")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or response_data.get("nextPageToken")
+ or response_data.get("next_page_token")
)
elif isinstance(data, list):
files = data
@@ -819,24 +844,61 @@ class ComposioService:
next_token = None
result_size_estimate = None
if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ response_data = (
+ inner_data.get("response_data", {})
+ if isinstance(inner_data, dict)
+ else {}
+ )
messages = (
data.get("messages", [])
- or data.get("data", {}).get("messages", [])
+ or (
+ inner_data.get("messages", [])
+ if isinstance(inner_data, dict)
+ else []
+ )
+ or response_data.get("messages", [])
or data.get("emails", [])
+ or (
+ inner_data.get("emails", [])
+ if isinstance(inner_data, dict)
+ else []
+ )
+ or response_data.get("emails", [])
)
# Check for pagination token in various possible locations
next_token = (
data.get("nextPageToken")
or data.get("next_page_token")
- or data.get("data", {}).get("nextPageToken")
- or data.get("data", {}).get("next_page_token")
+ or (
+ inner_data.get("nextPageToken")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or (
+ inner_data.get("next_page_token")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or response_data.get("nextPageToken")
+ or response_data.get("next_page_token")
)
# Extract resultSizeEstimate if available (Gmail API provides this)
result_size_estimate = (
data.get("resultSizeEstimate")
or data.get("result_size_estimate")
- or data.get("data", {}).get("resultSizeEstimate")
- or data.get("data", {}).get("result_size_estimate")
+ or (
+ inner_data.get("resultSizeEstimate")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or (
+ inner_data.get("result_size_estimate")
+ if isinstance(inner_data, dict)
+ else None
+ )
+ or response_data.get("resultSizeEstimate")
+ or response_data.get("result_size_estimate")
)
elif isinstance(data, list):
messages = data
@@ -864,7 +926,7 @@ class ComposioService:
try:
result = await self.execute_tool(
connected_account_id=connected_account_id,
- tool_name="GMAIL_GET_MESSAGE_BY_MESSAGE_ID",
+ tool_name="GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID",
params={"message_id": message_id}, # snake_case
entity_id=entity_id,
)
@@ -872,7 +934,13 @@ class ComposioService:
if not result.get("success"):
return None, result.get("error", "Unknown error")
- return result.get("data"), None
+ data = result.get("data")
+ if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ if isinstance(inner_data, dict):
+ return inner_data.get("response_data", inner_data), None
+
+ return data, None
except Exception as e:
logger.error(f"Failed to get Gmail message detail: {e!s}")
@@ -928,10 +996,27 @@ class ComposioService:
# Try different possible response structures
events = []
if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ response_data = (
+ inner_data.get("response_data", {})
+ if isinstance(inner_data, dict)
+ else {}
+ )
events = (
data.get("items", [])
- or data.get("data", {}).get("items", [])
+ or (
+ inner_data.get("items", [])
+ if isinstance(inner_data, dict)
+ else []
+ )
+ or response_data.get("items", [])
or data.get("events", [])
+ or (
+ inner_data.get("events", [])
+ if isinstance(inner_data, dict)
+ else []
+ )
+ or response_data.get("events", [])
)
elif isinstance(data, list):
events = data
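
Both the Drive and Gmail extraction paths above guard against the same three response shapes from Composio tool execution: a flat payload, a payload wrapped under "data", and one wrapped under "data" then "response_data". A hedged standalone sketch of that lookup order (helper name and sample payloads are invented, not part of the patch):

from typing import Any

def pick(data: Any, key: str, default: Any = None) -> Any:
    """Look up key at the top level, then under 'data', then 'data.response_data'."""
    if not isinstance(data, dict):
        return default
    inner = data.get("data", data)
    response = inner.get("response_data", {}) if isinstance(inner, dict) else {}
    return (
        data.get(key)
        or (inner.get(key) if isinstance(inner, dict) else None)
        or (response.get(key) if isinstance(response, dict) else None)
        or default
    )

flat = {"files": [{"id": "1"}], "nextPageToken": "t1"}
wrapped = {"data": {"files": [{"id": "2"}]}}
doubly_wrapped = {"data": {"response_data": {"files": [{"id": "3"}]}}}

assert pick(flat, "files") == [{"id": "1"}]
assert pick(wrapped, "files") == [{"id": "2"}]
assert pick(doubly_wrapped, "files") == [{"id": "3"}]
assert pick(flat, "nextPageToken") == "t1"
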
diff --git a/surfsense_backend/app/services/gmail/tool_metadata_service.py b/surfsense_backend/app/services/gmail/tool_metadata_service.py
index c903e24af..4855c1cc9 100644
--- a/surfsense_backend/app/services/gmail/tool_metadata_service.py
+++ b/surfsense_backend/app/services/gmail/tool_metadata_service.py
@@ -17,7 +17,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
-from app.utils.google_credentials import build_composio_credentials
+from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__)
@@ -78,14 +78,49 @@ class GmailToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
- async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
- if (
+ def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
+ return (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- return build_composio_credentials(cca_id)
+ )
+
+ def _get_composio_connected_account_id(
+ self, connector: SearchSourceConnector
+ ) -> str:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ raise ValueError("Composio connected_account_id not found")
+ return cca_id
+
+ def _unwrap_composio_data(self, data: Any) -> Any:
+ if isinstance(data, dict):
+ inner = data.get("data", data)
+ if isinstance(inner, dict):
+ return inner.get("response_data", inner)
+ return inner
+ return data
+
+ async def _execute_composio_gmail_tool(
+ self,
+ connector: SearchSourceConnector,
+ tool_name: str,
+ params: dict[str, Any],
+ ) -> tuple[Any, str | None]:
+ result = await ComposioService().execute_tool(
+ connected_account_id=self._get_composio_connected_account_id(connector),
+ tool_name=tool_name,
+ params=params,
+ entity_id=f"surfsense_{connector.user_id}",
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown Composio Gmail error")
+ return self._unwrap_composio_data(result.get("data")), None
+
+ async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
+ if self._is_composio_connector(connector):
+ raise ValueError(
+ "Composio Gmail connectors must use Composio tool execution"
+ )
config_data = dict(connector.config)
@@ -139,6 +174,12 @@ class GmailToolMetadataService:
if not connector:
return True
+ if self._is_composio_connector(connector):
+ _profile, error = await self._execute_composio_gmail_tool(
+ connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
+ )
+ return bool(error)
+
creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds)
await asyncio.get_event_loop().run_in_executor(
@@ -221,14 +262,21 @@ class GmailToolMetadataService:
)
connector = result.scalar_one_or_none()
if connector:
- creds = await self._build_credentials(connector)
- service = build("gmail", "v1", credentials=creds)
- profile = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda service=service: (
- service.users().getProfile(userId="me").execute()
- ),
- )
+ if self._is_composio_connector(connector):
+ profile, error = await self._execute_composio_gmail_tool(
+ connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
+ )
+ if error:
+ raise RuntimeError(error)
+ else:
+ creds = await self._build_credentials(connector)
+ service = build("gmail", "v1", credentials=creds)
+ profile = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda service=service: (
+ service.users().getProfile(userId="me").execute()
+ ),
+ )
acc_dict["email"] = profile.get("emailAddress", "")
except Exception:
logger.warning(
@@ -298,6 +346,23 @@ class GmailToolMetadataService:
Returns ``None`` on any failure so callers can degrade gracefully.
"""
try:
+ if self._is_composio_connector(connector):
+ if not draft_id:
+ draft_id = await self._find_composio_draft_id(connector, message_id)
+ if not draft_id:
+ return None
+
+ draft, error = await self._execute_composio_gmail_tool(
+ connector,
+ "GMAIL_GET_DRAFT",
+ {"user_id": "me", "draft_id": draft_id, "format": "full"},
+ )
+ if error or not isinstance(draft, dict):
+ return None
+
+ payload = draft.get("message", {}).get("payload", {})
+ return self._extract_body_from_payload(payload)
+
creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds)
@@ -326,6 +391,33 @@ class GmailToolMetadataService:
)
return None
+ async def _find_composio_draft_id(
+ self, connector: SearchSourceConnector, message_id: str
+ ) -> str | None:
+ page_token = ""
+ while True:
+ params: dict[str, Any] = {
+ "user_id": "me",
+ "max_results": 100,
+ "verbose": False,
+ }
+ if page_token:
+ params["page_token"] = page_token
+
+ data, error = await self._execute_composio_gmail_tool(
+ connector, "GMAIL_LIST_DRAFTS", params
+ )
+ if error or not isinstance(data, dict):
+ return None
+
+ for draft in data.get("drafts", []):
+ if draft.get("message", {}).get("id") == message_id:
+ return draft.get("id")
+
+ page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
+ if not page_token:
+ return None
+
async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
"""Resolve a draft ID from its message ID by scanning drafts.list."""
try:
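
_find_composio_draft_id resolves a Gmail draft ID by paging through GMAIL_LIST_DRAFTS and matching each draft's inner message ID. A minimal sketch of that scan with the tool call stubbed out (page contents invented):

import asyncio
from typing import Any

async def list_drafts(page_token: str) -> dict[str, Any]:
    # Stand-in for the GMAIL_LIST_DRAFTS tool; two pages of fake drafts.
    pages = {
        "": {"drafts": [{"id": "d1", "message": {"id": "m1"}}], "nextPageToken": "p2"},
        "p2": {"drafts": [{"id": "d2", "message": {"id": "m2"}}]},
    }
    return pages[page_token]

async def find_draft_id(message_id: str) -> str | None:
    page_token = ""
    while True:
        data = await list_drafts(page_token)
        for draft in data.get("drafts", []):
            if draft.get("message", {}).get("id") == message_id:
                return draft.get("id")
        page_token = data.get("nextPageToken") or ""
        if not page_token:
            return None

print(asyncio.run(find_draft_id("m2")))  # -> "d2"
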
diff --git a/surfsense_backend/app/services/google_calendar/kb_sync_service.py b/surfsense_backend/app/services/google_calendar/kb_sync_service.py
index 20426f3bc..602a55738 100644
--- a/surfsense_backend/app/services/google_calendar/kb_sync_service.py
+++ b/surfsense_backend/app/services/google_calendar/kb_sync_service.py
@@ -14,6 +14,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
+from app.services.composio_service import ComposioService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@@ -21,7 +22,6 @@ from app.utils.document_converters import (
generate_document_summary,
generate_unique_identifier_hash,
)
-from app.utils.google_credentials import build_composio_credentials
logger = logging.getLogger(__name__)
@@ -203,23 +203,46 @@ class GoogleCalendarKBSyncService:
logger.warning("Document %s not found in KB", document_id)
return {"status": "not_indexed"}
- creds = await self._build_credentials_for_connector(connector_id)
- loop = asyncio.get_event_loop()
- service = await loop.run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
-
calendar_id = (document.document_metadata or {}).get(
"calendar_id"
) or "primary"
- live_event = await loop.run_in_executor(
- None,
- lambda: (
- service.events()
- .get(calendarId=calendar_id, eventId=event_id)
- .execute()
- ),
- )
+ connector = await self._get_connector(connector_id)
+ if (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+ ):
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ raise ValueError("Composio connected_account_id not found")
+ composio_result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLECALENDAR_EVENTS_GET",
+ params={"calendar_id": calendar_id, "event_id": event_id},
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not composio_result.get("success"):
+ raise RuntimeError(
+ composio_result.get("error", "Unknown Composio Calendar error")
+ )
+ live_event = composio_result.get("data", {})
+ if isinstance(live_event, dict):
+ live_event = live_event.get("data", live_event)
+ if isinstance(live_event, dict):
+ live_event = live_event.get("response_data", live_event)
+ else:
+ creds = await self._build_credentials_for_connector(connector_id)
+ loop = asyncio.get_event_loop()
+ service = await loop.run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ live_event = await loop.run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .get(calendarId=calendar_id, eventId=event_id)
+ .execute()
+ ),
+ )
event_summary = live_event.get("summary", "")
description = live_event.get("description", "")
@@ -322,7 +345,7 @@ class GoogleCalendarKBSyncService:
await self.db_session.rollback()
return {"status": "error", "message": str(e)}
- async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
+ async def _get_connector(self, connector_id: int) -> SearchSourceConnector:
result = await self.db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
@@ -331,15 +354,17 @@ class GoogleCalendarKBSyncService:
connector = result.scalar_one_or_none()
if not connector:
raise ValueError(f"Connector {connector_id} not found")
+ return connector
+ async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
+ connector = await self._get_connector(connector_id)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- return build_composio_credentials(cca_id)
- raise ValueError("Composio connected_account_id not found")
+ raise ValueError(
+ "Composio Calendar connectors must use Composio tool execution"
+ )
config_data = dict(connector.config)
diff --git a/surfsense_backend/app/services/google_calendar/tool_metadata_service.py b/surfsense_backend/app/services/google_calendar/tool_metadata_service.py
index c7bfe1d50..7e50ab039 100644
--- a/surfsense_backend/app/services/google_calendar/tool_metadata_service.py
+++ b/surfsense_backend/app/services/google_calendar/tool_metadata_service.py
@@ -16,7 +16,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
-from app.utils.google_credentials import build_composio_credentials
+from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__)
@@ -94,15 +94,49 @@ class GoogleCalendarToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
- async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
- if (
+ def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
+ return (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- ):
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- return build_composio_credentials(cca_id)
+ )
+
+ def _get_composio_connected_account_id(
+ self, connector: SearchSourceConnector
+ ) -> str:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
raise ValueError("Composio connected_account_id not found")
+ return cca_id
+
+ async def _execute_composio_calendar_tool(
+ self,
+ connector: SearchSourceConnector,
+ tool_name: str,
+ params: dict,
+ ) -> tuple[dict | list | None, str | None]:
+ service = ComposioService()
+ result = await service.execute_tool(
+ connected_account_id=self._get_composio_connected_account_id(connector),
+ tool_name=tool_name,
+ params=params,
+ entity_id=f"surfsense_{connector.user_id}",
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown Composio Calendar error")
+
+ data = result.get("data")
+ if isinstance(data, dict):
+ inner = data.get("data", data)
+ if isinstance(inner, dict):
+ return inner.get("response_data", inner), None
+ return inner, None
+ return data, None
+
+ async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
+ if self._is_composio_connector(connector):
+ raise ValueError(
+ "Composio Calendar connectors must use Composio tool execution"
+ )
config_data = dict(connector.config)
@@ -156,6 +190,14 @@ class GoogleCalendarToolMetadataService:
if not connector:
return True
+ if self._is_composio_connector(connector):
+ _data, error = await self._execute_composio_calendar_tool(
+ connector,
+ "GOOGLECALENDAR_GET_CALENDAR",
+ {"calendar_id": "primary"},
+ )
+ return bool(error)
+
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
await loop.run_in_executor(
@@ -255,16 +297,48 @@ class GoogleCalendarToolMetadataService:
timezone_str = ""
if connector:
try:
- creds = await self._build_credentials(connector)
- loop = asyncio.get_event_loop()
- service = await loop.run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
+ if self._is_composio_connector(connector):
+ cal_list, cal_error = await self._execute_composio_calendar_tool(
+ connector, "GOOGLECALENDAR_LIST_CALENDARS", {}
+ )
+ if cal_error:
+ raise RuntimeError(cal_error)
+ (
+ settings,
+ settings_error,
+ ) = await self._execute_composio_calendar_tool(
+ connector,
+ "GOOGLECALENDAR_SETTINGS_GET",
+ {"setting": "timezone"},
+ )
+ if not settings_error and isinstance(settings, dict):
+ timezone_str = settings.get("value", "")
+ else:
+ creds = await self._build_credentials(connector)
+ loop = asyncio.get_event_loop()
+ service = await loop.run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
- cal_list = await loop.run_in_executor(
- None, lambda: service.calendarList().list().execute()
- )
- for cal in cal_list.get("items", []):
+ cal_list = await loop.run_in_executor(
+ None, lambda: service.calendarList().list().execute()
+ )
+
+ tz_setting = await loop.run_in_executor(
+ None,
+ lambda: service.settings().get(setting="timezone").execute(),
+ )
+ timezone_str = tz_setting.get("value", "")
+
+ calendar_items = []
+ if isinstance(cal_list, dict):
+ calendar_items = (
+ cal_list.get("items") or cal_list.get("calendars") or []
+ )
+ elif isinstance(cal_list, list):
+ calendar_items = cal_list
+
+ for cal in calendar_items:
calendars.append(
{
"id": cal.get("id", ""),
@@ -272,12 +346,6 @@ class GoogleCalendarToolMetadataService:
"primary": cal.get("primary", False),
}
)
-
- tz_setting = await loop.run_in_executor(
- None,
- lambda: service.settings().get(setting="timezone").execute(),
- )
- timezone_str = tz_setting.get("value", "")
except Exception:
logger.warning(
"Failed to fetch calendars/timezone for connector %s",
@@ -321,20 +389,29 @@ class GoogleCalendarToolMetadataService:
event_dict = event.to_dict()
try:
- creds = await self._build_credentials(connector)
- loop = asyncio.get_event_loop()
- service = await loop.run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
calendar_id = event.calendar_id or "primary"
- live_event = await loop.run_in_executor(
- None,
- lambda: (
- service.events()
- .get(calendarId=calendar_id, eventId=event.event_id)
- .execute()
- ),
- )
+ if self._is_composio_connector(connector):
+ live_event, error = await self._execute_composio_calendar_tool(
+ connector,
+ "GOOGLECALENDAR_EVENTS_GET",
+ {"calendar_id": calendar_id, "event_id": event.event_id},
+ )
+ if error:
+ raise RuntimeError(error)
+ else:
+ creds = await self._build_credentials(connector)
+ loop = asyncio.get_event_loop()
+ service = await loop.run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ live_event = await loop.run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .get(calendarId=calendar_id, eventId=event.event_id)
+ .execute()
+ ),
+ )
event_dict["summary"] = live_event.get("summary", event_dict["summary"])
event_dict["description"] = live_event.get(
@@ -376,12 +453,30 @@ class GoogleCalendarToolMetadataService:
) -> dict:
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
if not resolved:
+ live_resolved = await self._resolve_live_event(
+ search_space_id, user_id, event_ref
+ )
+ if not live_resolved:
+ return {
+ "error": (
+ f"Event '{event_ref}' not found in your indexed or live Google Calendar events. "
+ "This could mean: (1) the event doesn't exist, "
+ "(2) the event name is different, or "
+ "(3) the connected calendar account cannot access it."
+ )
+ }
+
+ connector, live_event = live_resolved
+ account = GoogleCalendarAccount.from_connector(connector)
+ acc_dict = account.to_dict()
+ auth_expired = await self._check_account_health(connector.id)
+ acc_dict["auth_expired"] = auth_expired
+ if auth_expired:
+ await self._persist_auth_expired(connector.id)
+
return {
- "error": (
- f"Event '{event_ref}' not found in your indexed Google Calendar events. "
- "This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
- "or (3) the event name is different."
- )
+ "account": acc_dict,
+ "event": self._event_dict_from_live_event(live_event),
}
document, connector = resolved
@@ -429,3 +524,110 @@ class GoogleCalendarToolMetadataService:
if row:
return row[0], row[1]
return None
+
+ async def _resolve_live_event(
+ self, search_space_id: int, user_id: str, event_ref: str
+ ) -> tuple[SearchSourceConnector, dict] | None:
+ result = await self._db_session.execute(
+ select(SearchSourceConnector)
+ .filter(
+ and_(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES),
+ )
+ )
+ .order_by(SearchSourceConnector.last_indexed_at.desc())
+ )
+ connectors = result.scalars().all()
+
+ for connector in connectors:
+ try:
+ events = await self._search_live_events(connector, event_ref)
+ except Exception:
+ logger.warning(
+ "Failed to search live calendar events for connector %s",
+ connector.id,
+ exc_info=True,
+ )
+ continue
+
+ if not events:
+ continue
+
+ normalized_ref = event_ref.strip().lower()
+ exact_match = next(
+ (
+ event
+ for event in events
+ if event.get("summary", "").strip().lower() == normalized_ref
+ ),
+ None,
+ )
+ return connector, exact_match or events[0]
+
+ return None
+
+ async def _search_live_events(
+ self, connector: SearchSourceConnector, event_ref: str
+ ) -> list[dict]:
+ if self._is_composio_connector(connector):
+ data, error = await self._execute_composio_calendar_tool(
+ connector,
+ "GOOGLECALENDAR_EVENTS_LIST",
+ {
+ "calendar_id": "primary",
+ "q": event_ref,
+ "max_results": 10,
+ "single_events": True,
+ "order_by": "startTime",
+ },
+ )
+ if error:
+ raise RuntimeError(error)
+ if isinstance(data, dict):
+ return data.get("items") or data.get("events") or []
+ return data if isinstance(data, list) else []
+
+ creds = await self._build_credentials(connector)
+ loop = asyncio.get_event_loop()
+ service = await loop.run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ response = await loop.run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .list(
+ calendarId="primary",
+ q=event_ref,
+ maxResults=10,
+ singleEvents=True,
+ orderBy="startTime",
+ )
+ .execute()
+ ),
+ )
+ return response.get("items", [])
+
+ def _event_dict_from_live_event(self, event: dict) -> dict:
+ start_data = event.get("start", {})
+ end_data = event.get("end", {})
+ return {
+ "event_id": event.get("id", ""),
+ "summary": event.get("summary", "No Title"),
+ "start": start_data.get("dateTime", start_data.get("date", "")),
+ "end": end_data.get("dateTime", end_data.get("date", "")),
+ "description": event.get("description", ""),
+ "location": event.get("location", ""),
+ "attendees": [
+ {
+ "email": attendee.get("email", ""),
+ "responseStatus": attendee.get("responseStatus", ""),
+ }
+ for attendee in event.get("attendees", [])
+ ],
+ "calendar_id": event.get("calendarId", "primary"),
+ "document_id": None,
+ "indexed_at": None,
+ }
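
When an event reference is missing from the knowledge base, the service now falls back to a live calendar search and prefers an exact case-insensitive summary match over the first hit. A small sketch of just that selection step (event dicts invented):

events = [
    {"id": "1", "summary": "Weekly Sync"},
    {"id": "2", "summary": "team offsite"},
]

def pick_event(events: list[dict], event_ref: str) -> dict | None:
    if not events:
        return None
    normalized_ref = event_ref.strip().lower()
    exact = next(
        (e for e in events if e.get("summary", "").strip().lower() == normalized_ref),
        None,
    )
    # Fall back to the first search result when nothing matches exactly.
    return exact or events[0]

assert pick_event(events, "Team Offsite")["id"] == "2"  # exact, case-insensitive
assert pick_event(events, "standup")["id"] == "1"       # fallback to first hit
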
diff --git a/surfsense_backend/app/services/google_drive/tool_metadata_service.py b/surfsense_backend/app/services/google_drive/tool_metadata_service.py
index 221bee14a..0f654bc78 100644
--- a/surfsense_backend/app/services/google_drive/tool_metadata_service.py
+++ b/surfsense_backend/app/services/google_drive/tool_metadata_service.py
@@ -13,7 +13,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
-from app.utils.google_credentials import build_composio_credentials
+from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__)
@@ -67,6 +67,42 @@ class GoogleDriveToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
+ def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
+ return (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
+ )
+
+ def _get_composio_connected_account_id(
+ self, connector: SearchSourceConnector
+ ) -> str:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ raise ValueError("Composio connected_account_id not found")
+ return cca_id
+
+ async def _execute_composio_drive_tool(
+ self,
+ connector: SearchSourceConnector,
+ tool_name: str,
+ params: dict,
+ ) -> tuple[dict | list | None, str | None]:
+ result = await ComposioService().execute_tool(
+ connected_account_id=self._get_composio_connected_account_id(connector),
+ tool_name=tool_name,
+ params=params,
+ entity_id=f"surfsense_{connector.user_id}",
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown Composio Drive error")
+ data = result.get("data")
+ if isinstance(data, dict):
+ inner = data.get("data", data)
+ if isinstance(inner, dict):
+ return inner.get("response_data", inner), None
+ return inner, None
+ return data, None
+
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
@@ -200,19 +236,21 @@ class GoogleDriveToolMetadataService:
if not connector:
return True
- pre_built_creds = None
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
- ):
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- pre_built_creds = build_composio_credentials(cca_id)
+ if self._is_composio_connector(connector):
+ _data, error = await self._execute_composio_drive_tool(
+ connector,
+ "GOOGLEDRIVE_LIST_FILES",
+ {
+ "q": "trashed = false",
+ "page_size": 1,
+ "fields": "files(id)",
+ },
+ )
+ return bool(error)
client = GoogleDriveClient(
session=self._db_session,
connector_id=connector_id,
- credentials=pre_built_creds,
)
await client.list_files(
query="trashed = false", page_size=1, fields="files(id)"
@@ -274,19 +312,39 @@ class GoogleDriveToolMetadataService:
parent_folders[connector_id] = []
continue
- pre_built_creds = None
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
- ):
- cca_id = connector.config.get("composio_connected_account_id")
- if cca_id:
- pre_built_creds = build_composio_credentials(cca_id)
+ if self._is_composio_connector(connector):
+ data, error = await self._execute_composio_drive_tool(
+ connector,
+ "GOOGLEDRIVE_LIST_FILES",
+ {
+ "q": "mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents",
+ "fields": "files(id,name)",
+ "page_size": 50,
+ },
+ )
+ if error:
+ logger.warning(
+ "Failed to list folders for connector %s: %s",
+ connector_id,
+ error,
+ )
+ parent_folders[connector_id] = []
+ continue
+ folders = []
+ if isinstance(data, dict):
+ folders = data.get("files", [])
+ elif isinstance(data, list):
+ folders = data
+ parent_folders[connector_id] = [
+ {"folder_id": f["id"], "name": f["name"]}
+ for f in folders
+ if f.get("id") and f.get("name")
+ ]
+ continue
client = GoogleDriveClient(
session=self._db_session,
connector_id=connector_id,
- credentials=pre_built_creds,
)
folders, _, error = await client.list_files(
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index c6ac3311a..5eb35f8b1 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -96,6 +96,46 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
+def _first_interrupt_value(state: Any) -> dict[str, Any] | None:
+ """Return the first LangGraph interrupt payload across all snapshot tasks."""
+ def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None:
+ if isinstance(candidate, dict):
+ value = candidate.get("value", candidate)
+ return value if isinstance(value, dict) else None
+ value = getattr(candidate, "value", None)
+ if isinstance(value, dict):
+ return value
+ if isinstance(candidate, (list, tuple)):
+ for item in candidate:
+ extracted = _extract_interrupt_value(item)
+ if extracted is not None:
+ return extracted
+ return None
+
+ for task in getattr(state, "tasks", ()) or ():
+ try:
+ interrupts = getattr(task, "interrupts", ()) or ()
+ except (AttributeError, IndexError, TypeError):
+ interrupts = ()
+ if not interrupts:
+ extracted = _extract_interrupt_value(task)
+ if extracted is not None:
+ return extracted
+ continue
+ for interrupt_item in interrupts:
+ extracted = _extract_interrupt_value(interrupt_item)
+ if extracted is not None:
+ return extracted
+ try:
+ state_interrupts = getattr(state, "interrupts", ()) or ()
+ except (AttributeError, IndexError, TypeError):
+ state_interrupts = ()
+ extracted = _extract_interrupt_value(state_interrupts)
+ if extracted is not None:
+ return extracted
+ return None
+
+
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
@@ -2178,10 +2218,10 @@ async def _stream_agent_events(
result.agent_called_update_memory = called_update_memory
_log_file_contract("turn_outcome", result)
- is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
- if is_interrupted:
+ interrupt_value = _first_interrupt_value(state)
+ if interrupt_value is not None:
result.is_interrupted = True
- result.interrupt_value = state.tasks[0].interrupts[0].value
+ result.interrupt_value = interrupt_value
yield streaming_service.format_interrupt_request(result.interrupt_value)
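
_first_interrupt_value replaces the old state.tasks[0].interrupts[0].value access, which assumes object-style interrupts and can break when a snapshot exposes them as plain dicts, nested sequences, or not at all. A hedged sketch of the shapes the extractor tolerates, using stand-in namespace objects rather than real LangGraph snapshot types:

from types import SimpleNamespace
from typing import Any

def extract_interrupt_value(candidate: Any) -> dict[str, Any] | None:
    # Mirrors the nested helper: dicts, objects with .value, and sequences.
    if isinstance(candidate, dict):
        value = candidate.get("value", candidate)
        return value if isinstance(value, dict) else None
    value = getattr(candidate, "value", None)
    if isinstance(value, dict):
        return value
    if isinstance(candidate, (list, tuple)):
        for item in candidate:
            extracted = extract_interrupt_value(item)
            if extracted is not None:
                return extracted
    return None

# Object-style interrupt (what the old code assumed) ...
task_obj = SimpleNamespace(
    interrupts=(SimpleNamespace(value={"tool": "create_calendar_event"}),)
)
# ... and a dict-style interrupt, which the old access pattern could not handle.
task_dict = SimpleNamespace(interrupts=[{"value": {"tool": "trash_google_drive_file"}}])

assert extract_interrupt_value(task_obj.interrupts) == {"tool": "create_calendar_event"}
assert extract_interrupt_value(task_dict.interrupts) == {"tool": "trash_google_drive_file"}
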
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
index 6912ffe5a..3c9f27303 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
@@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
+from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
-from app.utils.google_credentials import (
- COMPOSIO_GOOGLE_CONNECTOR_TYPES,
- build_composio_credentials,
-)
+from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
from .base import (
check_duplicate_document_by_hash,
@@ -44,6 +42,10 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30
+def _format_calendar_event_to_markdown(event: dict) -> str:
+    # format_event_to_markdown does not touch instance state, so it is called
+    # unbound with ``None`` for ``self`` when no connector client is built.
+    return GoogleCalendarConnector.format_event_to_markdown(None, event)
+
+
def _build_connector_doc(
event: dict,
event_markdown: str,
@@ -150,7 +152,14 @@ async def index_google_calendar_events(
)
return 0, 0, f"Connector with ID {connector_id} not found"
- # ── Credential building ───────────────────────────────────────
+ is_composio_connector = (
+ connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
+ )
+ calendar_client = None
+ composio_service = None
+ connected_account_id = None
+
+ # ── Credential/client building ────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
@@ -161,7 +170,7 @@ async def index_google_calendar_events(
{"error_type": "MissingComposioAccount"},
)
return 0, 0, "Composio connected_account_id not found"
- credentials = build_composio_credentials(connected_account_id)
+ composio_service = ComposioService()
else:
config_data = connector.config
@@ -229,12 +238,13 @@ async def index_google_calendar_events(
{"stage": "client_initialization"},
)
- calendar_client = GoogleCalendarConnector(
- credentials=credentials,
- session=session,
- user_id=user_id,
- connector_id=connector_id,
- )
+ if not is_composio_connector:
+ calendar_client = GoogleCalendarConnector(
+ credentials=credentials,
+ session=session,
+ user_id=user_id,
+ connector_id=connector_id,
+ )
# Handle 'undefined' string from frontend (treat as None)
if start_date == "undefined" or start_date == "":
@@ -300,9 +310,26 @@ async def index_google_calendar_events(
)
try:
- events, error = await calendar_client.get_all_primary_calendar_events(
- start_date=start_date_str, end_date=end_date_str
- )
+ if is_composio_connector:
+ start_dt = parse_date_flexible(start_date_str).replace(
+ hour=0, minute=0, second=0, microsecond=0
+ )
+ end_dt = parse_date_flexible(end_date_str).replace(
+ hour=23, minute=59, second=59, microsecond=0
+ )
+ events, error = await composio_service.get_calendar_events(
+ connected_account_id=connected_account_id,
+ entity_id=f"surfsense_{user_id}",
+ time_min=start_dt.isoformat(),
+ time_max=end_dt.isoformat(),
+ max_results=250,
+ )
+ if not events and not error:
+ error = "No events found in the specified date range."
+ else:
+ events, error = await calendar_client.get_all_primary_calendar_events(
+ start_date=start_date_str, end_date=end_date_str
+ )
if error:
if "No events found" in error:
@@ -381,7 +408,7 @@ async def index_google_calendar_events(
documents_skipped += 1
continue
- event_markdown = calendar_client.format_event_to_markdown(event)
+ event_markdown = _format_calendar_event_to_markdown(event)
if not event_markdown.strip():
logger.warning(f"Skipping event with no content: {event_summary}")
documents_skipped += 1
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
index 21cdbd29f..686f13d9e 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
@@ -9,6 +9,8 @@ import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
+from pathlib import Path
+from typing import Any
from sqlalchemy import String, cast, select
from sqlalchemy.exc import SQLAlchemyError
@@ -37,6 +39,7 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
+from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.page_limit_service import PageLimitService
from app.services.task_logging_service import TaskLoggingService
@@ -45,10 +48,7 @@ from app.tasks.connector_indexers.base import (
get_connector_by_id,
update_connector_last_indexed,
)
-from app.utils.google_credentials import (
- COMPOSIO_GOOGLE_CONNECTOR_TYPES,
- build_composio_credentials,
-)
+from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
ACCEPTED_DRIVE_CONNECTOR_TYPES = {
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
@@ -61,6 +61,209 @@ HEARTBEAT_INTERVAL_SECONDS = 30
logger = logging.getLogger(__name__)
+class ComposioDriveClient:
+ """Google Drive client facade backed by Composio tool execution.
+
+ Composio-managed OAuth connections can execute tools without exposing raw
+ OAuth tokens through connected account state.
+ """
+
+ def __init__(
+ self,
+ session: AsyncSession,
+ connector_id: int,
+ connected_account_id: str,
+ entity_id: str,
+ ):
+ self.session = session
+ self.connector_id = connector_id
+ self.connected_account_id = connected_account_id
+ self.entity_id = entity_id
+ self.composio = ComposioService()
+
+ async def list_files(
+ self,
+ query: str = "",
+ fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)",
+ page_size: int = 100,
+ page_token: str | None = None,
+ ) -> tuple[list[dict[str, Any]], str | None, str | None]:
+ params: dict[str, Any] = {
+ "page_size": min(page_size, 100),
+ "fields": fields,
+ }
+ if query:
+ params["q"] = query
+ if page_token:
+ params["page_token"] = page_token
+
+ result = await self.composio.execute_tool(
+ connected_account_id=self.connected_account_id,
+ tool_name="GOOGLEDRIVE_LIST_FILES",
+ params=params,
+ entity_id=self.entity_id,
+ )
+ if not result.get("success"):
+ return [], None, result.get("error", "Unknown error")
+
+ data = result.get("data", {})
+ files = []
+ next_token = None
+ if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ if isinstance(inner_data, dict):
+ files = inner_data.get("files", [])
+ next_token = inner_data.get("nextPageToken") or inner_data.get(
+ "next_page_token"
+ )
+ elif isinstance(data, list):
+ files = data
+
+ return files, next_token, None
+
+ async def get_file_metadata(
+ self, file_id: str, fields: str = "*"
+ ) -> tuple[dict[str, Any] | None, str | None]:
+ result = await self.composio.execute_tool(
+ connected_account_id=self.connected_account_id,
+ tool_name="GOOGLEDRIVE_GET_FILE_METADATA",
+ params={"file_id": file_id, "fields": fields},
+ entity_id=self.entity_id,
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown error")
+
+ data = result.get("data", {})
+ if isinstance(data, dict):
+ inner_data = data.get("data", data)
+ if isinstance(inner_data, dict):
+ return inner_data, None
+
+ return None, "Could not extract metadata from Composio response"
+
+ async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
+ return await self._download_file_content(file_id)
+
+ async def download_file_to_disk(
+ self,
+ file_id: str,
+ dest_path: str,
+ chunksize: int = 5 * 1024 * 1024,
+ ) -> str | None:
+ del chunksize
+ content, error = await self.download_file(file_id)
+ if error:
+ return error
+ if content is None:
+ return "No content returned from Composio"
+ Path(dest_path).write_bytes(content)
+ return None
+
+ async def export_google_file(
+ self, file_id: str, mime_type: str
+ ) -> tuple[bytes | None, str | None]:
+ return await self._download_file_content(file_id, mime_type=mime_type)
+
+ async def _download_file_content(
+ self, file_id: str, mime_type: str | None = None
+ ) -> tuple[bytes | None, str | None]:
+ params: dict[str, Any] = {"file_id": file_id}
+ if mime_type:
+ params["mime_type"] = mime_type
+
+ result = await self.composio.execute_tool(
+ connected_account_id=self.connected_account_id,
+ tool_name="GOOGLEDRIVE_DOWNLOAD_FILE",
+ params=params,
+ entity_id=self.entity_id,
+ )
+ if not result.get("success"):
+ return None, result.get("error", "Unknown error")
+
+ return self._read_download_result(result.get("data"))
+
+ def _read_download_result(self, data: Any) -> tuple[bytes | None, str | None]:
+ if isinstance(data, bytes):
+ return data, None
+
+ file_path: str | None = None
+ if isinstance(data, str):
+ file_path = data
+ elif isinstance(data, dict):
+ inner_data = data.get("data", data)
+ if isinstance(inner_data, dict):
+ for key in ("file_path", "downloaded_file_content", "path", "uri"):
+ value = inner_data.get(key)
+ if isinstance(value, str):
+ file_path = value
+ break
+ if isinstance(value, dict):
+ nested = (
+ value.get("file_path")
+ or value.get("downloaded_file_content")
+ or value.get("path")
+ or value.get("uri")
+ or value.get("s3url")
+ )
+ if isinstance(nested, str):
+ file_path = nested
+ break
+
+ if not file_path:
+ return None, "No file path/content returned from Composio"
+
+ if file_path.startswith(("http://", "https://")):
+ try:
+ import urllib.request
+
+ with urllib.request.urlopen(file_path, timeout=60) as response:
+ return response.read(), None
+ except Exception as e:
+ return None, f"Failed to download Composio file URL: {e!s}"
+
+ path_obj = Path(file_path)
+ if path_obj.is_absolute() or ".composio" in str(path_obj):
+ if not path_obj.exists():
+ return None, f"File not found at path: {file_path}"
+ return path_obj.read_bytes(), None
+
+        # Neither a URL nor an on-disk path: treat the string as inline content.
+        # Decode strictly as base64 first; anything that fails validation is
+        # returned as raw UTF-8 text.
+        try:
+            import base64
+
+            return base64.b64decode(file_path, validate=True), None
+        except Exception:
+            return file_path.encode("utf-8"), None
+
+
+def _build_drive_client_for_connector(
+ session: AsyncSession,
+ connector_id: int,
+ connector: object,
+ user_id: str,
+) -> tuple[GoogleDriveClient | ComposioDriveClient | None, str | None]:
+ if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
+ connected_account_id = connector.config.get("composio_connected_account_id")
+ if not connected_account_id:
+ return None, (
+ f"Composio connected_account_id not found for connector {connector_id}"
+ )
+ return (
+ ComposioDriveClient(
+ session,
+ connector_id,
+ connected_account_id,
+ entity_id=f"surfsense_{user_id}",
+ ),
+ None,
+ )
+
+ token_encrypted = connector.config.get("_token_encrypted", False)
+ if token_encrypted and not config.SECRET_KEY:
+ return None, "SECRET_KEY not configured but credentials are marked as encrypted"
+
+ return GoogleDriveClient(session, connector_id), None
+
+
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -927,34 +1130,17 @@ async def index_google_drive_files(
{"stage": "client_initialization"},
)
- pre_built_credentials = None
- if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
- connected_account_id = connector.config.get("composio_connected_account_id")
- if not connected_account_id:
- error_msg = f"Composio connected_account_id not found for connector {connector_id}"
- await task_logger.log_task_failure(
- log_entry,
- error_msg,
- "Missing Composio account",
- {"error_type": "MissingComposioAccount"},
- )
- return 0, 0, error_msg, 0
- pre_built_credentials = build_composio_credentials(connected_account_id)
- else:
- token_encrypted = connector.config.get("_token_encrypted", False)
- if token_encrypted and not config.SECRET_KEY:
- await task_logger.log_task_failure(
- log_entry,
- "SECRET_KEY not configured but credentials are encrypted",
- "Missing SECRET_KEY",
- {"error_type": "MissingSecretKey"},
- )
- return (
- 0,
- 0,
- "SECRET_KEY not configured but credentials are marked as encrypted",
- 0,
- )
+ drive_client, client_error = _build_drive_client_for_connector(
+ session, connector_id, connector, user_id
+ )
+ if client_error or not drive_client:
+ await task_logger.log_task_failure(
+ log_entry,
+ client_error or "Failed to initialize Google Drive client",
+ "Missing connector credentials",
+ {"error_type": "ClientInitializationError"},
+ )
+        return 0, 0, client_error or "Failed to initialize Google Drive client", 0
connector_enable_summary = getattr(connector, "enable_summary", True)
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
@@ -963,10 +1149,6 @@ async def index_google_drive_files(
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
- drive_client = GoogleDriveClient(
- session, connector_id, credentials=pre_built_credentials
- )
-
if not folder_id:
error_msg = "folder_id is required for Google Drive indexing"
await task_logger.log_task_failure(
@@ -979,8 +1161,14 @@ async def index_google_drive_files(
folder_tokens = connector.config.get("folder_tokens", {})
start_page_token = folder_tokens.get(target_folder_id)
+ is_composio_connector = (
+ connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
+ )
can_use_delta = (
- use_delta_sync and start_page_token and connector.last_indexed_at
+ not is_composio_connector
+ and use_delta_sync
+ and start_page_token
+ and connector.last_indexed_at
)
documents_unsupported = 0
@@ -1051,7 +1239,16 @@ async def index_google_drive_files(
)
if documents_indexed > 0 or can_use_delta:
- new_token, token_error = await get_start_page_token(drive_client)
+ if isinstance(drive_client, ComposioDriveClient):
+ (
+ new_token,
+ token_error,
+ ) = await drive_client.composio.get_drive_start_page_token(
+ drive_client.connected_account_id,
+ drive_client.entity_id,
+ )
+ else:
+ new_token, token_error = await get_start_page_token(drive_client)
if new_token and not token_error:
await session.refresh(connector)
if "folder_tokens" not in connector.config:
@@ -1137,32 +1334,17 @@ async def index_google_drive_single_file(
)
return 0, error_msg
- pre_built_credentials = None
- if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
- connected_account_id = connector.config.get("composio_connected_account_id")
- if not connected_account_id:
- error_msg = f"Composio connected_account_id not found for connector {connector_id}"
- await task_logger.log_task_failure(
- log_entry,
- error_msg,
- "Missing Composio account",
- {"error_type": "MissingComposioAccount"},
- )
- return 0, error_msg
- pre_built_credentials = build_composio_credentials(connected_account_id)
- else:
- token_encrypted = connector.config.get("_token_encrypted", False)
- if token_encrypted and not config.SECRET_KEY:
- await task_logger.log_task_failure(
- log_entry,
- "SECRET_KEY not configured but credentials are encrypted",
- "Missing SECRET_KEY",
- {"error_type": "MissingSecretKey"},
- )
- return (
- 0,
- "SECRET_KEY not configured but credentials are marked as encrypted",
- )
+ drive_client, client_error = _build_drive_client_for_connector(
+ session, connector_id, connector, user_id
+ )
+ if client_error or not drive_client:
+ await task_logger.log_task_failure(
+ log_entry,
+ client_error or "Failed to initialize Google Drive client",
+ "Missing connector credentials",
+ {"error_type": "ClientInitializationError"},
+ )
+        return 0, client_error or "Failed to initialize Google Drive client"
connector_enable_summary = getattr(connector, "enable_summary", True)
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
@@ -1171,10 +1353,6 @@ async def index_google_drive_single_file(
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
- drive_client = GoogleDriveClient(
- session, connector_id, credentials=pre_built_credentials
- )
-
file, error = await get_file_by_id(drive_client, file_id)
if error or not file:
error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"
@@ -1276,32 +1454,18 @@ async def index_google_drive_selected_files(
)
return 0, 0, [error_msg]
- pre_built_credentials = None
- if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
- connected_account_id = connector.config.get("composio_connected_account_id")
- if not connected_account_id:
- error_msg = f"Composio connected_account_id not found for connector {connector_id}"
- await task_logger.log_task_failure(
- log_entry,
- error_msg,
- "Missing Composio account",
- {"error_type": "MissingComposioAccount"},
- )
- return 0, 0, [error_msg]
- pre_built_credentials = build_composio_credentials(connected_account_id)
- else:
- token_encrypted = connector.config.get("_token_encrypted", False)
- if token_encrypted and not config.SECRET_KEY:
- error_msg = (
- "SECRET_KEY not configured but credentials are marked as encrypted"
- )
- await task_logger.log_task_failure(
- log_entry,
- error_msg,
- "Missing SECRET_KEY",
- {"error_type": "MissingSecretKey"},
- )
- return 0, 0, [error_msg]
+ drive_client, client_error = _build_drive_client_for_connector(
+ session, connector_id, connector, user_id
+ )
+ if client_error or not drive_client:
+ error_msg = client_error or "Failed to initialize Google Drive client"
+ await task_logger.log_task_failure(
+ log_entry,
+ error_msg,
+ "Missing connector credentials",
+ {"error_type": "ClientInitializationError"},
+ )
+ return 0, 0, [error_msg]
connector_enable_summary = getattr(connector, "enable_summary", True)
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
@@ -1310,10 +1474,6 @@ async def index_google_drive_selected_files(
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
- drive_client = GoogleDriveClient(
- session, connector_id, credentials=pre_built_credentials
- )
-
indexed, skipped, unsupported, errors = await _index_selected_files(
drive_client,
session,
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
index ef226087b..6697c0eb1 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
@@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
+from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
-from app.utils.google_credentials import (
- COMPOSIO_GOOGLE_CONNECTOR_TYPES,
- build_composio_credentials,
-)
+from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
from .base import (
calculate_date_range,
@@ -44,6 +42,62 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30
+def _normalize_composio_gmail_message(message: dict) -> dict:
+ if message.get("payload"):
+ return message
+
+ headers = []
+ header_values = {
+ "Subject": message.get("subject"),
+ "From": message.get("from") or message.get("sender"),
+ "To": message.get("to") or message.get("recipient"),
+ "Date": message.get("date"),
+ }
+ for name, value in header_values.items():
+ if value:
+ headers.append({"name": name, "value": value})
+
+ return {
+ **message,
+ "id": message.get("id")
+ or message.get("message_id")
+ or message.get("messageId"),
+ "threadId": message.get("threadId") or message.get("thread_id"),
+ "payload": {"headers": headers},
+ "snippet": message.get("snippet", ""),
+ "messageText": message.get("messageText") or message.get("body") or "",
+ }
+
+
+def _format_gmail_message_to_markdown(message: dict) -> str:
+ headers = {
+ header.get("name", "").lower(): header.get("value", "")
+ for header in message.get("payload", {}).get("headers", [])
+ if isinstance(header, dict)
+ }
+ subject = headers.get("subject", "No Subject")
+ from_email = headers.get("from", "Unknown Sender")
+ to_email = headers.get("to", "Unknown Recipient")
+ date_str = headers.get("date", "Unknown Date")
+ message_text = (
+ message.get("messageText")
+ or message.get("body")
+ or message.get("text")
+ or message.get("snippet", "")
+ )
+
+ return (
+ f"# {subject}\n\n"
+ f"**From:** {from_email}\n"
+ f"**To:** {to_email}\n"
+ f"**Date:** {date_str}\n\n"
+ f"## Message Content\n\n{message_text}\n\n"
+ f"## Message Details\n\n"
+ f"- **Message ID:** {message.get('id', 'Unknown')}\n"
+ f"- **Thread ID:** {message.get('threadId', 'Unknown')}\n"
+ )
+
+
def _build_connector_doc(
message: dict,
markdown_content: str,
@@ -162,7 +216,14 @@ async def index_google_gmail_messages(
)
return 0, 0, error_msg
- # ── Credential building ───────────────────────────────────────
+ is_composio_connector = (
+ connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
+ )
+ gmail_connector = None
+ composio_service = None
+ connected_account_id = None
+
+ # ── Credential/client building ────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
@@ -173,7 +234,7 @@ async def index_google_gmail_messages(
{"error_type": "MissingComposioAccount"},
)
return 0, 0, "Composio connected_account_id not found"
- credentials = build_composio_credentials(connected_account_id)
+ composio_service = ComposioService()
else:
config_data = connector.config
@@ -241,9 +302,10 @@ async def index_google_gmail_messages(
{"stage": "client_initialization"},
)
- gmail_connector = GoogleGmailConnector(
- credentials, session, user_id, connector_id
- )
+ if not is_composio_connector:
+ gmail_connector = GoogleGmailConnector(
+ credentials, session, user_id, connector_id
+ )
calculated_start_date, calculated_end_date = calculate_date_range(
connector, start_date, end_date, default_days_back=365
@@ -254,11 +316,60 @@ async def index_google_gmail_messages(
f"Fetching emails for connector {connector_id} "
f"from {calculated_start_date} to {calculated_end_date}"
)
- messages, error = await gmail_connector.get_recent_messages(
- max_results=max_messages,
- start_date=calculated_start_date,
- end_date=calculated_end_date,
- )
+ if is_composio_connector:
+ query_parts = []
+ if calculated_start_date:
+ query_parts.append(f"after:{calculated_start_date.replace('-', '/')}")
+ if calculated_end_date:
+ query_parts.append(f"before:{calculated_end_date.replace('-', '/')}")
+ query = " ".join(query_parts)
+
+ messages = []
+ page_token = None
+ error = None
+ while len(messages) < max_messages:
+ page_size = min(50, max_messages - len(messages))
+ (
+ page_messages,
+ page_token,
+ _estimate,
+ page_error,
+ ) = await composio_service.get_gmail_messages(
+ connected_account_id=connected_account_id,
+ entity_id=f"surfsense_{user_id}",
+ query=query,
+ max_results=page_size,
+ page_token=page_token,
+ )
+ if page_error:
+ error = page_error
+ break
+ for page_message in page_messages:
+ message_id = (
+ page_message.get("id")
+ or page_message.get("message_id")
+ or page_message.get("messageId")
+ )
+ if message_id:
+ (
+ detail,
+ detail_error,
+ ) = await composio_service.get_gmail_message_detail(
+ connected_account_id=connected_account_id,
+ entity_id=f"surfsense_{user_id}",
+ message_id=message_id,
+ )
+ if not detail_error and isinstance(detail, dict):
+ page_message = detail
+ messages.append(_normalize_composio_gmail_message(page_message))
+ if not page_token:
+ break
+ else:
+ messages, error = await gmail_connector.get_recent_messages(
+ max_results=max_messages,
+ start_date=calculated_start_date,
+ end_date=calculated_end_date,
+ )
if error:
error_message = error
@@ -326,7 +437,12 @@ async def index_google_gmail_messages(
documents_skipped += 1
continue
- markdown_content = gmail_connector.format_message_to_markdown(message)
+ if is_composio_connector:
+ markdown_content = _format_gmail_message_to_markdown(message)
+ else:
+ markdown_content = gmail_connector.format_message_to_markdown(
+ message
+ )
if not markdown_content.strip():
logger.warning(f"Skipping message with no content: {message_id}")
documents_skipped += 1
diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
index 9258d5cfe..0693dfebb 100644
--- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
+++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
@@ -51,22 +51,34 @@ class _FakeToolMessage:
tool_call_id: str | None = None
+@dataclass
+class _FakeInterrupt:
+ value: dict[str, Any]
+
+
+@dataclass
+class _FakeTask:
+ interrupts: tuple[_FakeInterrupt, ...] = ()
+
+
class _FakeAgentState:
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
- def __init__(self) -> None:
+ def __init__(self, tasks: list[Any] | None = None) -> None:
# Empty values keeps the cloud-fallback safety-net branch a no-op,
- # and an empty ``tasks`` list keeps the post-stream interrupt
- # check a no-op too.
+ # and empty ``tasks`` keep the post-stream interrupt check a no-op too.
self.values: dict[str, Any] = {}
- self.tasks: list[Any] = []
+ self.tasks: list[Any] = tasks or []
class _FakeAgent:
"""Replays a list of ``astream_events`` events."""
- def __init__(self, events: list[dict[str, Any]]) -> None:
+ def __init__(
+ self, events: list[dict[str, Any]], state: _FakeAgentState | None = None
+ ) -> None:
self._events = events
+ self._state = state or _FakeAgentState()
async def astream_events( # type: ignore[no-untyped-def]
self, _input_data: Any, *, config: dict[str, Any], version: str
@@ -79,7 +91,7 @@ class _FakeAgent:
# Called once after astream_events drains so the cloud-fallback
# safety net can inspect staged filesystem work. The fake stays
# empty so the safety net is a no-op.
- return _FakeAgentState()
+ return self._state
def _model_stream(
@@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
)
-async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
+async def _drain(
+ events: list[dict[str, Any]], state: _FakeAgentState | None = None
+) -> list[dict[str, Any]]:
"""Run ``_stream_agent_events`` against a fake agent and return the
SSE payloads (parsed JSON) it yielded.
"""
- agent = _FakeAgent(events)
+ agent = _FakeAgent(events, state=state)
service = VercelStreamingService()
result = StreamResult()
config = {"configurable": {"thread_id": "test-thread"}}
@@ -525,3 +539,29 @@ async def test_unmatched_fallback_still_attaches_lc_id(
assert len(starts) == 1
assert starts[0]["toolCallId"].startswith("call_run-1")
assert starts[0]["langchainToolCallId"] == "lc-orphan"
+
+
+@pytest.mark.asyncio
+async def test_interrupt_request_uses_task_that_contains_interrupt(
+ parity_v2_on: None,
+) -> None:
+ interrupt_payload = {
+ "type": "calendar_event_create",
+ "action": {
+ "tool": "create_calendar_event",
+ "params": {"summary": "mom bday"},
+ },
+ "context": {},
+ }
+ state = _FakeAgentState(
+ tasks=[
+ _FakeTask(interrupts=()),
+ _FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)),
+ ]
+ )
+
+ payloads = await _drain([], state=state)
+
+ interrupts = _of_type(payloads, "data-interrupt-request")
+ assert len(interrupts) == 1
+ assert interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event"
From bdb97a0888543ea5d5b8b3902efe1c3a808abf3f Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Sat, 2 May 2026 22:25:04 -0700
Subject: [PATCH 05/12] chore: linting
---
surfsense_backend/app/tasks/chat/stream_new_chat.py | 1 +
.../tests/unit/tasks/chat/test_tool_input_streaming.py | 4 +++-
.../config/connector-status-config.json | 10 ++++++++++
.../connector-popup/constants/connector-constants.ts | 4 ++--
4 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index 5eb35f8b1..268a4401e 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -98,6 +98,7 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
def _first_interrupt_value(state: Any) -> dict[str, Any] | None:
"""Return the first LangGraph interrupt payload across all snapshot tasks."""
+
def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None:
if isinstance(candidate, dict):
value = candidate.get("value", candidate)
diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
index 0693dfebb..60750396c 100644
--- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
+++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
@@ -564,4 +564,6 @@ async def test_interrupt_request_uses_task_that_contains_interrupt(
interrupts = _of_type(payloads, "data-interrupt-request")
assert len(interrupts) == 1
- assert interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event"
+ assert (
+ interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event"
+ )
diff --git a/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json b/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json
index f62758256..b4e85eab0 100644
--- a/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json
+++ b/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json
@@ -9,6 +9,16 @@
"enabled": true,
"status": "warning",
"statusMessage": "Some requests may be blocked if not using Firecrawl."
+ },
+ "JIRA_CONNECTOR": {
+ "enabled": false,
+ "status": "maintenance",
+ "statusMessage": "Rework in progress."
+ },
+ "CONFLUENCE_CONNECTOR": {
+ "enabled": false,
+ "status": "maintenance",
+ "statusMessage": "Rework in progress."
}
},
"globalSettings": {
diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts
index ae2c413cf..2f9605ea7 100644
--- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts
+++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts
@@ -105,14 +105,14 @@ export const OAUTH_CONNECTORS = [
{
id: "jira-connector",
title: "Jira",
- description: "Search, read, and manage issues",
+ description: "Rework in progress.",
connectorType: EnumConnectorName.JIRA_CONNECTOR,
authEndpoint: "/api/v1/auth/mcp/jira/connector/add/",
},
{
id: "confluence-connector",
title: "Confluence",
- description: "Search documentation",
+ description: "Rework in progress.",
connectorType: EnumConnectorName.CONFLUENCE_CONNECTOR,
authEndpoint: "/api/v1/auth/confluence/connector/add/",
},
From c938d39277225f425cda24ea56ca50a0ed93e30a Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Sat, 2 May 2026 23:10:48 -0700
Subject: [PATCH 06/12] feat: moved most things behind correct feature flag
---
docker/.env.example | 18 +++
.../app/agents/new_chat/feature_flags.py | 106 +++++++++++-------
.../app/routes/agent_flags_route.py | 8 +-
.../app/services/auto_model_pin_service.py | 2 +-
.../agents/new_chat/test_feature_flags.py | 38 ++++---
.../services/test_auto_model_pin_service.py | 55 ++++++++-
surfsense_web/app/(home)/pricing/page.tsx | 2 +-
.../new-chat/[[...chat_id]]/page.tsx | 18 ++-
.../components/AgentStatusContent.tsx | 13 +++
.../layout/ui/sidebar/DocumentsSidebar.tsx | 10 +-
.../components/pricing/pricing-section.tsx | 37 +++---
surfsense_web/lib/agent-filesystem.ts | 13 ++-
.../lib/apis/agent-flags-api.service.ts | 2 +
13 files changed, 237 insertions(+), 85 deletions(-)
diff --git a/docker/.env.example b/docker/.env.example
index c2e87a619..fd56bdccc 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -308,6 +308,24 @@ STT_SERVICE=local/base
# Advanced (optional)
# ------------------------------------------------------------------------------
+# New-chat agent feature flags
+SURFSENSE_ENABLE_CONTEXT_EDITING=true
+SURFSENSE_ENABLE_COMPACTION_V2=true
+SURFSENSE_ENABLE_RETRY_AFTER=true
+SURFSENSE_ENABLE_MODEL_FALLBACK=false
+SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
+SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
+SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
+SURFSENSE_ENABLE_BUSY_MUTEX=true
+SURFSENSE_ENABLE_SKILLS=true
+SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true
+SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true
+SURFSENSE_ENABLE_ACTION_LOG=true
+SURFSENSE_ENABLE_REVERT_ROUTE=true
+SURFSENSE_ENABLE_PERMISSION=true
+SURFSENSE_ENABLE_DOOM_LOOP=true
+SURFSENSE_ENABLE_STREAM_PARITY_V2=true
+
# Periodic connector sync interval (default: 5m)
# SCHEDULE_CHECKER_INTERVAL=5m
diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py
index f58bf0dd7..5007d89a5 100644
--- a/surfsense_backend/app/agents/new_chat/feature_flags.py
+++ b/surfsense_backend/app/agents/new_chat/feature_flags.py
@@ -3,8 +3,10 @@ Feature flags for the SurfSense new_chat agent stack.
These flags gate the newer agent middleware (some ported from OpenCode,
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
-SurfSense-native). They follow a "default-OFF for risky things,
-default-ON for safe upgrades, master kill-switch for everything new" model.
+SurfSense-native). Most shipped agent-stack upgrades now default ON, so an
+updated Docker image picks up the new behavior even when an older install's
+environment is missing the newly introduced variables. Risky or experimental
+integrations stay default OFF, and the master kill-switch can still disable
+everything new.
All new middleware checks its flag at agent build time. If the master
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
@@ -14,16 +16,19 @@ operators a single switch to revert to pre-port behavior.
Examples
--------
-Local development (recommended for trying everything except doom-loop / selector):
+Defaults:
SURFSENSE_ENABLE_CONTEXT_EDITING=true
SURFSENSE_ENABLE_COMPACTION_V2=true
SURFSENSE_ENABLE_RETRY_AFTER=true
+ SURFSENSE_ENABLE_MODEL_FALLBACK=false
+ SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
+ SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
- SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
- SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
- SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
- SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events
+ SURFSENSE_ENABLE_PERMISSION=true
+ SURFSENSE_ENABLE_DOOM_LOOP=true
+ SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
+ SURFSENSE_ENABLE_STREAM_PARITY_V2=true
Master kill-switch (overrides everything else):
@@ -60,32 +65,28 @@ class AgentFeatureFlags:
disable_new_agent_stack: bool = False
# Agent quality — context budget, retry/limits, name-repair, doom-loop
- enable_context_editing: bool = False
- enable_compaction_v2: bool = False
- enable_retry_after: bool = False
+ enable_context_editing: bool = True
+ enable_compaction_v2: bool = True
+ enable_retry_after: bool = True
enable_model_fallback: bool = False
- enable_model_call_limit: bool = False
- enable_tool_call_limit: bool = False
- enable_tool_call_repair: bool = False
- enable_doom_loop: bool = (
- False # Default OFF until UI handles permission='doom_loop'
- )
+ enable_model_call_limit: bool = True
+ enable_tool_call_limit: bool = True
+ enable_tool_call_repair: bool = True
+ enable_doom_loop: bool = True
# Safety — permissions, concurrency, tool-set narrowing
- enable_permission: bool = False # Default OFF for first deploy
- enable_busy_mutex: bool = False
+ enable_permission: bool = True
+ enable_busy_mutex: bool = True
enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost
# Skills + subagents
- enable_skills: bool = False
- enable_specialized_subagents: bool = False
- enable_kb_planner_runnable: bool = False
+ enable_skills: bool = True
+ enable_specialized_subagents: bool = True
+ enable_kb_planner_runnable: bool = True
# Snapshot / revert
- enable_action_log: bool = False
- enable_revert_route: bool = (
- False # Backend ships before UI; route returns 503 until this flips
- )
+ enable_action_log: bool = True
+ enable_revert_route: bool = True
# Streaming parity v2 — opt in to LangChain's structured
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
@@ -94,7 +95,7 @@ class AgentFeatureFlags:
# text path and the synthetic ``call_`` tool-call id (no
# ``langchainToolCallId`` propagation). Schema migrations 135/136
# ship unconditionally because they're forward-compatible.
- enable_stream_parity_v2: bool = False
+ enable_stream_parity_v2: bool = True
# Plugins
enable_plugin_loader: bool = False
@@ -115,43 +116,64 @@ class AgentFeatureFlags:
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
"middleware is forced OFF for this build."
)
- return cls(disable_new_agent_stack=True)
+ return cls(
+ disable_new_agent_stack=True,
+ enable_context_editing=False,
+ enable_compaction_v2=False,
+ enable_retry_after=False,
+ enable_model_fallback=False,
+ enable_model_call_limit=False,
+ enable_tool_call_limit=False,
+ enable_tool_call_repair=False,
+ enable_doom_loop=False,
+ enable_permission=False,
+ enable_busy_mutex=False,
+ enable_llm_tool_selector=False,
+ enable_skills=False,
+ enable_specialized_subagents=False,
+ enable_kb_planner_runnable=False,
+ enable_action_log=False,
+ enable_revert_route=False,
+ enable_stream_parity_v2=False,
+ enable_plugin_loader=False,
+ enable_otel=False,
+ )
return cls(
disable_new_agent_stack=False,
# Agent quality
- enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False),
- enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False),
- enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False),
+ enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True),
+ enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True),
+ enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True),
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
enable_model_call_limit=_env_bool(
- "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False
+ "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True
),
- enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False),
+ enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True),
enable_tool_call_repair=_env_bool(
- "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False
+ "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True
),
- enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False),
+ enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True),
# Safety
- enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False),
- enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False),
+ enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True),
+ enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True),
enable_llm_tool_selector=_env_bool(
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
),
# Skills + subagents
- enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False),
+ enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True),
enable_specialized_subagents=_env_bool(
- "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False
+ "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True
),
enable_kb_planner_runnable=_env_bool(
- "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False
+ "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True
),
# Snapshot / revert
- enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
- enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
+ enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
+ enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
# Streaming parity v2
enable_stream_parity_v2=_env_bool(
- "SURFSENSE_ENABLE_STREAM_PARITY_V2", False
+ "SURFSENSE_ENABLE_STREAM_PARITY_V2", True
),
# Plugins
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
diff --git a/surfsense_backend/app/routes/agent_flags_route.py b/surfsense_backend/app/routes/agent_flags_route.py
index 5732a8dfb..99388af66 100644
--- a/surfsense_backend/app/routes/agent_flags_route.py
+++ b/surfsense_backend/app/routes/agent_flags_route.py
@@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
+from app.config import config
from app.db import User
from app.users import current_active_user
@@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel):
enable_otel: bool
+ enable_desktop_local_filesystem: bool
+
@classmethod
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
- return cls(**asdict(flags))
+ return cls(
+ **asdict(flags),
+ enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM,
+ )
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py
index 4f045ba02..185035b8a 100644
--- a/surfsense_backend/app/services/auto_model_pin_service.py
+++ b/surfsense_backend/app/services/auto_model_pin_service.py
@@ -399,7 +399,7 @@ async def resolve_or_get_pinned_llm_config_id(
False if force_repin_free else await _is_premium_eligible(session, user_id)
)
if premium_eligible:
- eligible = candidates
+ eligible = [c for c in candidates if _tier_of(c) == "premium"]
else:
eligible = [c for c in candidates if _tier_of(c) != "premium"]
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py
index 38a70a443..df60a4816 100644
--- a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py
@@ -31,18 +31,38 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"SURFSENSE_ENABLE_ACTION_LOG",
"SURFSENSE_ENABLE_REVERT_ROUTE",
+ "SURFSENSE_ENABLE_STREAM_PARITY_V2",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
"SURFSENSE_ENABLE_OTEL",
]:
monkeypatch.delenv(name, raising=False)
-def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None:
+def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None:
_clear_all(monkeypatch)
flags = reload_for_tests()
assert isinstance(flags, AgentFeatureFlags)
assert flags.disable_new_agent_stack is False
- assert flags.any_new_middleware_enabled() is False
+ assert flags.enable_context_editing is True
+ assert flags.enable_compaction_v2 is True
+ assert flags.enable_retry_after is True
+ assert flags.enable_model_fallback is False
+ assert flags.enable_model_call_limit is True
+ assert flags.enable_tool_call_limit is True
+ assert flags.enable_tool_call_repair is True
+ assert flags.enable_doom_loop is True
+ assert flags.enable_permission is True
+ assert flags.enable_busy_mutex is True
+ assert flags.enable_llm_tool_selector is False
+ assert flags.enable_skills is True
+ assert flags.enable_specialized_subagents is True
+ assert flags.enable_kb_planner_runnable is True
+ assert flags.enable_action_log is True
+ assert flags.enable_revert_route is True
+ assert flags.enable_stream_parity_v2 is True
+ assert flags.enable_plugin_loader is False
+ assert flags.enable_otel is False
+ assert flags.any_new_middleware_enabled() is True
def test_master_kill_switch_overrides_individual_flags(
@@ -100,21 +120,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
+ "enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
"enable_otel": "SURFSENSE_ENABLE_OTEL",
}
- # `enable_otel` is intentionally orthogonal — it does NOT count toward
- # ``any_new_middleware_enabled`` because OTel is observability-only and
- # ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
- counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}
-
for attr, env_name in flag_to_env.items():
_clear_all(monkeypatch)
- monkeypatch.setenv(env_name, "true")
+ monkeypatch.setenv(env_name, "false")
flags = reload_for_tests()
- assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}"
- if attr in counts_toward_middleware:
- assert flags.any_new_middleware_enabled() is True
- else:
- assert flags.any_new_middleware_enabled() is False
+ assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}"
diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
index 49b3621c7..c8d6dc1ca 100644
--- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
+++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
@@ -101,11 +101,58 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
- assert result.resolved_llm_config_id in {-1, -2}
+ assert result.resolved_llm_config_id == -1
assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id
assert session.commit_count == 1
+@pytest.mark.asyncio
+async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -2,
+ "provider": "OPENAI",
+ "model_name": "gpt-free",
+ "api_key": "k1",
+ "billing_tier": "free",
+ "quality_score": 100,
+ },
+ {
+ "id": -1,
+ "provider": "OPENAI",
+ "model_name": "gpt-prem",
+ "api_key": "k2",
+ "billing_tier": "premium",
+ "quality_score": 10,
+ },
+ ],
+ )
+
+ async def _allowed(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=True)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -1
+ assert result.resolved_tier == "premium"
+
+
@pytest.mark.asyncio
async def test_next_turn_reuses_existing_pin(monkeypatch):
from app.config import config
@@ -361,12 +408,12 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
],
)
- async def _allowed(*_args, **_kwargs):
- return _FakeQuotaResult(allowed=True)
+ async def _blocked(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=False)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
- _allowed,
+ _blocked,
)
result = await resolve_or_get_pinned_llm_config_id(
diff --git a/surfsense_web/app/(home)/pricing/page.tsx b/surfsense_web/app/(home)/pricing/page.tsx
index 6f332be70..2a413b9a9 100644
--- a/surfsense_web/app/(home)/pricing/page.tsx
+++ b/surfsense_web/app/(home)/pricing/page.tsx
@@ -5,7 +5,7 @@ import { BreadcrumbNav } from "@/components/seo/breadcrumb-nav";
export const metadata: Metadata = {
title: "Pricing | SurfSense - Free AI Search Plans",
description:
- "Explore SurfSense plans and pricing. Start free with 500 pages & $5 of premium credit. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost — $1 buys $1 of credit.",
+ "Explore SurfSense plans and pricing. Start free with 500 pages & $5 in premium credits. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost.",
alternates: {
canonical: "https://surfsense.com/pricing",
},
diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
index 39201e5cc..4c8e4fe93 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
@@ -13,6 +13,7 @@ import { useParams } from "next/navigation";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { z } from "zod";
+import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms";
import {
clearTargetCommentIdAtom,
@@ -393,6 +394,8 @@ export default function NewChatPage() {
// Get current user for author info in shared chats
const { data: currentUser } = useAtomValue(currentUserAtom);
+ const { data: agentFlags } = useAtomValue(agentFlagsAtom);
+ const localFilesystemEnabled = agentFlags?.enable_desktop_local_filesystem === true;
// Live collaboration: sync session state and messages via Zero
useChatSessionStateSync(threadId);
@@ -989,7 +992,9 @@ export default function NewChatPage() {
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
- const selection = await getAgentFilesystemSelection(searchSpaceId);
+ const selection = await getAgentFilesystemSelection(searchSpaceId, {
+ localFilesystemEnabled,
+ });
if (
selection.filesystem_mode === "desktop_local_folder" &&
(!selection.local_filesystem_mounts || selection.local_filesystem_mounts.length === 0)
@@ -1311,6 +1316,7 @@ export default function NewChatPage() {
setAgentCreatedDocuments,
queryClient,
currentUser,
+ localFilesystemEnabled,
disabledTools,
updateChatTabTitle,
tokenUsageStore,
@@ -1413,7 +1419,9 @@ export default function NewChatPage() {
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
- const selection = await getAgentFilesystemSelection(searchSpaceId);
+ const selection = await getAgentFilesystemSelection(searchSpaceId, {
+ localFilesystemEnabled,
+ });
const response = await fetchWithTurnCancellingRetry(() =>
fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, {
method: "POST",
@@ -1561,6 +1569,7 @@ export default function NewChatPage() {
pendingInterrupt,
messages,
searchSpaceId,
+ localFilesystemEnabled,
queryClient,
tokenUsageStore,
fetchWithTurnCancellingRetry,
@@ -1746,7 +1755,9 @@ export default function NewChatPage() {
? messageDocumentsMap[sourceUserMessageId]
: [];
try {
- const selection = await getAgentFilesystemSelection(searchSpaceId);
+ const selection = await getAgentFilesystemSelection(searchSpaceId, {
+ localFilesystemEnabled,
+ });
const requestBody: Record = {
search_space_id: searchSpaceId,
user_query: newUserQuery,
@@ -2016,6 +2027,7 @@ export default function NewChatPage() {
searchSpaceId,
messages,
disabledTools,
+ localFilesystemEnabled,
messageDocumentsMap,
setMessageDocumentsMap,
queryClient,
diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx
index bd8f03a70..17d8aa50c 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx
@@ -178,6 +178,19 @@ const FLAG_GROUPS: FlagGroup[] = [
},
],
},
+ {
+ id: "desktop",
+ title: "Desktop",
+ subtitle: "Desktop-only capabilities exposed by the backend deployment.",
+ flags: [
+ {
+ key: "enable_desktop_local_filesystem",
+ label: "Local filesystem",
+ description: "Allow Desktop chat sessions to operate directly on selected local folders.",
+ envVar: "ENABLE_DESKTOP_LOCAL_FILESYSTEM",
+ },
+ ],
+ },
];
function FlagRow({ def, value }: { def: FlagDef; value: boolean }) {
diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx
index bf4de6454..8d59363a6 100644
--- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx
+++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx
@@ -23,6 +23,7 @@ import { useTranslations } from "next-intl";
import type React from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
+import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom";
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
@@ -197,6 +198,7 @@ function AuthenticatedDocumentsSidebarBase({
const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom);
const setRightPanelCollapsed = useSetAtom(rightPanelCollapsedAtom);
const openEditorPanel = useSetAtom(openEditorPanelAtom);
+ const { data: agentFlags } = useAtomValue(agentFlagsAtom);
const { data: connectors } = useAtomValue(connectorsAtom);
const connectorCount = connectors?.length ?? 0;
@@ -209,6 +211,7 @@ function AuthenticatedDocumentsSidebarBase({
const [watchedFolderIds, setWatchedFolderIds] = useState>(new Set());
const [folderWatchOpen, setFolderWatchOpen] = useAtom(folderWatchDialogOpenAtom);
const [watchInitialFolder, setWatchInitialFolder] = useAtom(folderWatchInitialFolderAtom);
+ const localFilesystemEnabled = agentFlags?.enable_desktop_local_filesystem === true;
const isElectron =
desktopFeaturesEnabled && typeof window !== "undefined" && !!window.electronAPI;
@@ -1036,9 +1039,12 @@ function AuthenticatedDocumentsSidebarBase({
return () => document.removeEventListener("keydown", handleEscape);
}, [open, onOpenChange, isMobile, setRightPanelCollapsed]);
- const showFilesystemTabs = !isMobile && !!electronAPI && !!filesystemSettings;
+ const showFilesystemTabs =
+ !isMobile && !!electronAPI && !!filesystemSettings && localFilesystemEnabled;
const currentFilesystemTab =
- filesystemSettings?.mode === "desktop_local_folder" ? "local" : "cloud";
+ localFilesystemEnabled && filesystemSettings?.mode === "desktop_local_folder"
+ ? "local"
+ : "cloud";
const showCloudSkeleton =
currentFilesystemTab === "cloud" &&
(zeroFoldersResult.type !== "complete" || zeroAllDocsResult.type !== "complete");
diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx
index 156ef9134..4ba1ecc1e 100644
--- a/surfsense_web/components/pricing/pricing-section.tsx
+++ b/surfsense_web/components/pricing/pricing-section.tsx
@@ -12,11 +12,11 @@ const demoPlans = [
price: "0",
yearlyPrice: "0",
period: "",
- billingText: "500 pages + $5 of premium credit included",
+ billingText: "500 pages + $5 in premium credits included",
features: [
"Self Hostable",
"500 pages included to start",
- "$5 of premium credit to start, billed at provider cost",
+ "$5 in premium credits for paid AI models and premium AI features",
"Includes access to OpenAI text, audio and image models",
"Realtime Collaborative Group Chats with teammates",
"Community support on Discord",
@@ -35,7 +35,7 @@ const demoPlans = [
features: [
"Everything in Free",
"Buy 1,000-page packs at $1 each",
- "Top up premium credit at $1 per $1 of credit, billed at provider cost",
+ "Top up premium credits at $1 per $1 of credit, billed at provider cost",
"Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter",
"Priority support on Discord",
],
@@ -89,7 +89,7 @@ const faqData: FAQSection[] = [
{
question: "What are Basic and Premium processing modes?",
answer:
- "When uploading documents, you can choose between two processing modes. Basic mode uses standard extraction and costs 1 page credit per page, great for most documents. Premium mode uses advanced extraction optimized for complex financial, medical, and legal documents with intricate tables, layouts, and formatting. Premium costs 10 page credits per page but delivers significantly higher fidelity output for these specialized document types.",
+ "When uploading documents, you can choose between two processing modes. Basic mode uses standard extraction and costs 1 page credit per page, great for most documents. Premium processing mode uses advanced extraction optimized for complex financial, medical, and legal documents with intricate tables, layouts, and formatting. It costs 10 page credits per page and does not use your premium AI credits.",
},
{
question: "How does the Pay As You Go plan work?",
@@ -129,27 +129,32 @@ const faqData: FAQSection[] = [
],
},
{
- title: "Premium Credit",
+ title: "Premium Credits",
items: [
{
- question: 'What is "premium credit"?',
+ question: 'What are "premium credits"?',
answer:
- "Premium credit is your USD balance for using premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro in SurfSense. Each AI request debits the actual USD cost the provider charges, so cheap and expensive models bill proportionally. Non-premium models (such as the free-tier models available without login) don't touch your premium credit.",
+ "Premium credits are your USD balance for paid AI usage in SurfSense, including premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro, plus premium AI features such as image generation, podcasts, and video presentations when they use paid models. Each request debits the actual USD provider cost, so cheaper and more expensive models bill proportionally.",
},
{
- question: "How much premium credit do I get for free?",
+ question: "How many premium credits do I get for free?",
answer:
- "Every registered SurfSense account starts with $5 of premium credit at no cost. Anonymous users (no login) get 500,000 free tokens across all free models. Once your free credit runs out, you can top up at any time.",
+ "Every registered SurfSense account starts with $5 in premium credits at no cost. Anonymous users (no login) get 500,000 free tokens across free models before creating an account. Once your included premium credits run out, you can top up at any time.",
},
{
- question: "How does buying premium credit work?",
+ question: "How does buying premium credits work?",
answer:
- "Just like pages, there's no subscription. Top-ups buy $1 of credit for $1 — every cent you pay is spent at provider cost, no markup. Purchased credit is added to your account immediately. You can buy up to $100 at a time.",
+ "Premium credit top-ups are pay as you go, with no subscription. $1 buys $1 of credit, and your balance is spent at provider cost. Purchased credit is added to your account immediately. You can buy up to $100 at a time.",
},
{
- question: "What happens if I run out of premium credit?",
+ question: "Are premium credits the same as page credits?",
answer:
- "When your premium credit balance runs low (below 20%), you'll see a warning. Once you run out, premium model requests are paused until you top up. You can always switch to non-premium models, which don't touch your premium credit.",
+ "No. Page credits pay for document indexing and file-based connector processing. Premium credits pay for paid AI usage, such as premium model chats and premium AI generation features. Premium document processing mode sounds similar, but it consumes page credits, not premium credits.",
+ },
+ {
+ question: "What happens if I run out of premium credits?",
+ answer:
+ "When your premium credit balance runs low, you'll see a warning. Once you run out, paid model requests and premium AI features are paused until you top up. You can still use non-premium models and features that do not consume premium credits.",
},
],
},
@@ -159,7 +164,7 @@ const faqData: FAQSection[] = [
{
question: "Can I self-host SurfSense with unlimited pages and credit?",
answer:
- "Yes! When self-hosting, you have full control over your page and premium-credit limits. The default self-hosted setup gives you effectively unlimited pages and premium credit, so you can index as much data and use as many AI queries as your infrastructure supports.",
+ "Yes! When self-hosting, you have full control over your page and premium credit limits. The default self-hosted setup gives you effectively unlimited pages and premium credits, so you can index as much data and use as many AI queries as your infrastructure supports.",
},
],
},
@@ -250,7 +255,7 @@ function PricingFAQ() {
Frequently Asked Questions
- Everything you need to know about SurfSense pages, premium credit, and billing. Can't
+ Everything you need to know about SurfSense pages, premium credits, and billing. Can't
find what you need? Reach out at{" "}
rohan@surfsense.com
@@ -335,7 +340,7 @@ function PricingBasic() {
>
diff --git a/surfsense_web/lib/agent-filesystem.ts b/surfsense_web/lib/agent-filesystem.ts
index da5fc1b1d..5f8066d27 100644
--- a/surfsense_web/lib/agent-filesystem.ts
+++ b/surfsense_web/lib/agent-filesystem.ts
@@ -12,6 +12,10 @@ export interface AgentFilesystemSelection {
local_filesystem_mounts?: AgentFilesystemMountSelection[];
}
+export interface AgentFilesystemSelectionOptions {
+ localFilesystemEnabled: boolean;
+}
+
const DEFAULT_SELECTION: AgentFilesystemSelection = {
filesystem_mode: "cloud",
client_platform: "web",
@@ -23,10 +27,15 @@ export function getClientPlatform(): ClientPlatform {
}
export async function getAgentFilesystemSelection(
- searchSpaceId?: number | null
+ searchSpaceId?: number | null,
+ options?: AgentFilesystemSelectionOptions
): Promise<AgentFilesystemSelection> {
const platform = getClientPlatform();
- if (platform !== "desktop" || !window.electronAPI?.getAgentFilesystemSettings) {
+ if (
+ platform !== "desktop" ||
+ !options?.localFilesystemEnabled ||
+ !window.electronAPI?.getAgentFilesystemSettings
+ ) {
return { ...DEFAULT_SELECTION, client_platform: platform };
}
try {
diff --git a/surfsense_web/lib/apis/agent-flags-api.service.ts b/surfsense_web/lib/apis/agent-flags-api.service.ts
index 87332ca9f..534810c0e 100644
--- a/surfsense_web/lib/apis/agent-flags-api.service.ts
+++ b/surfsense_web/lib/apis/agent-flags-api.service.ts
@@ -27,6 +27,8 @@ const AgentFeatureFlagsSchema = z.object({
enable_plugin_loader: z.boolean(),
enable_otel: z.boolean(),
+
+ enable_desktop_local_filesystem: z.boolean(),
});
export type AgentFeatureFlags = z.infer<typeof AgentFeatureFlagsSchema>;
From e4f9d79635d827adcc7e70279ccec9ac31482fa1 Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Sat, 2 May 2026 23:35:47 -0700
Subject: [PATCH 07/12] feat: add preferred premium auto configuration logic
and corresponding tests
---
.../app/services/auto_model_pin_service.py | 15 ++++-
.../services/test_auto_model_pin_service.py | 58 +++++++++++++++++++
.../components/pricing/pricing-section.tsx | 3 +-
3 files changed, 73 insertions(+), 3 deletions(-)
diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py
index 185035b8a..9bbca8669 100644
--- a/surfsense_backend/app/services/auto_model_pin_service.py
+++ b/surfsense_backend/app/services/auto_model_pin_service.py
@@ -220,6 +220,15 @@ def _tier_of(cfg: dict) -> str:
return str(cfg.get("billing_tier", "free")).lower()
+def _is_preferred_premium_auto_config(cfg: dict) -> bool:
+ """Return True for the operator-preferred premium Auto model."""
+ return (
+ _tier_of(cfg) == "premium"
+ and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI"
+ and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
+ )
+
+
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
"""Pick a config with quality-first ranking + deterministic spread.
@@ -399,7 +408,11 @@ async def resolve_or_get_pinned_llm_config_id(
False if force_repin_free else await _is_premium_eligible(session, user_id)
)
if premium_eligible:
- eligible = [c for c in candidates if _tier_of(c) == "premium"]
+ premium_candidates = [c for c in candidates if _tier_of(c) == "premium"]
+ preferred_premium = [
+ c for c in premium_candidates if _is_preferred_premium_auto_config(c)
+ ]
+ eligible = preferred_premium or premium_candidates
else:
eligible = [c for c in candidates if _tier_of(c) != "premium"]
diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
index c8d6dc1ca..d1af29aeb 100644
--- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
+++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
@@ -153,6 +153,64 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
assert result.resolved_tier == "premium"
+@pytest.mark.asyncio
+async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
+ from app.config import config
+
+ session = _FakeSession(_thread())
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5.1",
+ "api_key": "k1",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score": 100,
+ },
+ {
+ "id": -2,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5.4",
+ "api_key": "k2",
+ "billing_tier": "premium",
+ "auto_pin_tier": "A",
+ "quality_score": 10,
+ },
+ {
+ "id": -3,
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-5.4",
+ "api_key": "k3",
+ "billing_tier": "premium",
+ "auto_pin_tier": "B",
+ "quality_score": 100,
+ },
+ ],
+ )
+
+ async def _allowed(*_args, **_kwargs):
+ return _FakeQuotaResult(allowed=True)
+
+ monkeypatch.setattr(
+ "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+ _allowed,
+ )
+
+ result = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=1,
+ search_space_id=10,
+ user_id="00000000-0000-0000-0000-000000000001",
+ selected_llm_config_id=0,
+ )
+ assert result.resolved_llm_config_id == -2
+ assert result.resolved_tier == "premium"
+
+
@pytest.mark.asyncio
async def test_next_turn_reuses_existing_pin(monkeypatch):
from app.config import config
diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx
index 4ba1ecc1e..07c11b4d6 100644
--- a/surfsense_web/components/pricing/pricing-section.tsx
+++ b/surfsense_web/components/pricing/pricing-section.tsx
@@ -34,8 +34,7 @@ const demoPlans = [
billingText: "No subscription, buy only when you need more",
features: [
"Everything in Free",
- "Buy 1,000-page packs at $1 each",
- "Top up premium credits at $1 per $1 of credit, billed at provider cost",
+ "Buy 1,000-page packs or $1 in premium credits at $1 each",
"Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter",
"Priority support on Discord",
],
From 30d06affdc7deba50fbe00a38ca2dd4ae564394d Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Sat, 2 May 2026 23:40:44 -0700
Subject: [PATCH 08/12] chore: bumped version to 0.0.20
---
VERSION | 2 +-
surfsense_backend/pyproject.toml | 2 +-
surfsense_backend/uv.lock | 2 +-
surfsense_browser_extension/package.json | 2 +-
surfsense_desktop/package.json | 2 +-
surfsense_web/package.json | 2 +-
6 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/VERSION b/VERSION
index 44517d518..fe04e7f67 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.0.19
+0.0.20
diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml
index cd683e2e1..b9c389734 100644
--- a/surfsense_backend/pyproject.toml
+++ b/surfsense_backend/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "surf-new-backend"
-version = "0.0.19"
+version = "0.0.20"
description = "SurfSense Backend"
requires-python = ">=3.12"
dependencies = [
diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock
index efe670d05..46dd0b613 100644
--- a/surfsense_backend/uv.lock
+++ b/surfsense_backend/uv.lock
@@ -7947,7 +7947,7 @@ wheels = [
[[package]]
name = "surf-new-backend"
-version = "0.0.19"
+version = "0.0.20"
source = { editable = "." }
dependencies = [
{ name = "alembic" },
diff --git a/surfsense_browser_extension/package.json b/surfsense_browser_extension/package.json
index 146dd177e..1ffc4dd87 100644
--- a/surfsense_browser_extension/package.json
+++ b/surfsense_browser_extension/package.json
@@ -1,7 +1,7 @@
{
"name": "surfsense_browser_extension",
"displayName": "Surfsense Browser Extension",
- "version": "0.0.19",
+ "version": "0.0.20",
"description": "Extension to collect Browsing History for SurfSense.",
"author": "https://github.com/MODSetter",
"engines": {
diff --git a/surfsense_desktop/package.json b/surfsense_desktop/package.json
index e2712d8ea..960267e16 100644
--- a/surfsense_desktop/package.json
+++ b/surfsense_desktop/package.json
@@ -1,6 +1,6 @@
{
"name": "surfsense-desktop",
- "version": "0.0.19",
+ "version": "0.0.20",
"description": "SurfSense Desktop App",
"main": "dist/main.js",
"scripts": {
diff --git a/surfsense_web/package.json b/surfsense_web/package.json
index 41175daeb..399544019 100644
--- a/surfsense_web/package.json
+++ b/surfsense_web/package.json
@@ -1,6 +1,6 @@
{
"name": "surfsense_web",
- "version": "0.0.19",
+ "version": "0.0.20",
"private": true,
"description": "SurfSense Frontend",
"scripts": {
From cab1dd6fb26808f0f5b948c6fba53558e9b6c615 Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Sun, 3 May 2026 00:39:27 -0700
Subject: [PATCH 09/12] fix: docker issues
---
surfsense_backend/Dockerfile | 39 ++++++++++++++++++++++----------
surfsense_backend/pyproject.toml | 2 +-
surfsense_backend/uv.lock | 2 +-
3 files changed, 29 insertions(+), 14 deletions(-)
diff --git a/surfsense_backend/Dockerfile b/surfsense_backend/Dockerfile
index 1222b36b6..73d5819b9 100644
--- a/surfsense_backend/Dockerfile
+++ b/surfsense_backend/Dockerfile
@@ -38,16 +38,26 @@ RUN pip install --upgrade certifi pip-system-certs
COPY pyproject.toml .
COPY uv.lock .
-# Install PyTorch based on architecture
-RUN if [ "$(uname -m)" = "x86_64" ]; then \
- pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121; \
- else \
- pip install --no-cache-dir torch torchvision torchaudio; \
- fi
-
-# Install python dependencies
+# Install all Python dependencies from uv.lock for deterministic builds.
+#
+# `uv pip install -e .` re-resolves from pyproject.toml and ignores uv.lock,
+# which lets prod silently drift to newer upstream versions on every rebuild
+# (e.g. deepagents 0.4.x -> 0.5.x breaking the FilesystemMiddleware imports).
+# Exporting the lock to requirements.txt and feeding it to `uv pip install`
+# pins every transitive package to the exact version captured in uv.lock.
+#
+# Note on torch/CUDA: we do NOT install torch from a separate cu* index here.
+# PyPI's torch wheels for Linux x86_64 already ship CUDA-enabled and pull
+# nvidia-cudnn-cu13, nvidia-nccl-cu13, triton, etc. as install deps (all
+# captured in uv.lock). Installing from cu121 first only wasted ~2GB of
+# downloads that the lock-based install immediately replaced. If a specific
+# CUDA version is needed (driver compatibility, etc.), wire it through
+# [tool.uv.sources] in pyproject.toml so the lock stays the source of truth.
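+#
+# A minimal sketch of that pinning, assuming a CUDA 12.1 wheel index (the
+# index name below is illustrative; point the URL at whatever CUDA build the
+# target drivers actually need):
+#
+#   [tool.uv.sources]
+#   torch = { index = "pytorch-cu121" }
+#
+#   [[tool.uv.index]]
+#   name = "pytorch-cu121"
+#   url = "https://download.pytorch.org/whl/cu121"
+#   explicit = true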
RUN pip install --no-cache-dir uv && \
- uv pip install --system --no-cache-dir -e .
+ uv export --frozen --no-dev --no-hashes --no-emit-project \
+ --format requirements-txt -o /tmp/requirements.txt && \
+ uv pip install --system --no-cache-dir -r /tmp/requirements.txt && \
+ rm /tmp/requirements.txt
# Set SSL environment variables dynamically
RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \
@@ -66,13 +76,18 @@ RUN cd /root/.EasyOCR/model && (unzip -o english_g2.zip || true) && (unzip -o cr
# Pre-download Docling models
RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true
-# Install Playwright browsers for web scraping if needed
-RUN pip install playwright && \
- playwright install chromium --with-deps
+# Install Playwright browsers for web scraping (the playwright package itself
+# is already installed via uv.lock above)
+RUN playwright install chromium --with-deps
# Copy source code
COPY . .
+# Install the project itself in editable mode. Dependencies were already
+# installed deterministically from uv.lock above, so --no-deps prevents any
+# re-resolution that could pull newer versions.
+RUN uv pip install --system --no-cache-dir --no-deps -e .
+
# Copy and set permissions for entrypoint script
# Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts)
COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh
diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml
index b9c389734..b2bf17305 100644
--- a/surfsense_backend/pyproject.toml
+++ b/surfsense_backend/pyproject.toml
@@ -71,11 +71,11 @@ dependencies = [
"langchain>=1.2.13",
"langgraph>=1.1.3",
"langchain-community>=0.4.1",
- "deepagents>=0.4.12",
"stripe>=15.0.0",
"azure-ai-documentintelligence>=1.0.2",
"litellm>=1.83.7",
"langchain-litellm>=0.6.4",
+ "deepagents>=0.4.12,<0.5",
]
[dependency-groups]
diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock
index 46dd0b613..ffc977262 100644
--- a/surfsense_backend/uv.lock
+++ b/surfsense_backend/uv.lock
@@ -8045,7 +8045,7 @@ requires-dist = [
{ name = "composio", specifier = ">=0.10.9" },
{ name = "datasets", specifier = ">=2.21.0" },
{ name = "daytona", specifier = ">=0.146.0" },
- { name = "deepagents", specifier = ">=0.4.12" },
+ { name = "deepagents", specifier = ">=0.4.12,<0.5" },
{ name = "discord-py", specifier = ">=2.5.2" },
{ name = "docling", specifier = ">=2.15.0" },
{ name = "elasticsearch", specifier = ">=9.1.1" },
From a34f1fb25c0cf9eb8de4d57080be880b1acb2e66 Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Sun, 3 May 2026 06:03:40 -0700
Subject: [PATCH 10/12] feat: implement agent caches and fix invalid prompt
cache configs
- Added a new function `_warm_agent_jit_caches` to pre-warm agent caches at startup, reducing cold invocation costs.
- Updated the `SurfSenseContextSchema` to include per-invocation fields for better state management during agent execution.
- Refactored connector tools to open fresh, short-lived database sessions per call so cached agent closures can be safely reused across requests.
- Enhanced middleware to support new context features and improve error handling during connector and document type discovery.
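- Minimal sketch of the new build path (abridged; the real key also hashes the tool surface, feature flags, system prompt, and filesystem mode; see `chat_deepagent.py` in this patch):

      cache_key = stable_hash("v1", config_id, thread_id, user_id, search_space_id)
      agent = await get_cache().get_or_build(cache_key, builder=_build_agent)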
---
surfsense_backend/.env.example | 27 +
.../app/agents/new_chat/agent_cache.py | 357 ++++++++++
.../app/agents/new_chat/chat_deepagent.py | 142 +++-
.../app/agents/new_chat/context.py | 63 +-
.../app/agents/new_chat/feature_flags.py | 42 ++
.../agents/new_chat/middleware/__init__.py | 4 +
.../new_chat/middleware/flatten_system.py | 233 +++++++
.../new_chat/middleware/knowledge_search.py | 42 +-
.../app/agents/new_chat/prompt_caching.py | 36 +-
.../new_chat/tools/confluence/create_page.py | 293 +++++----
.../new_chat/tools/confluence/delete_page.py | 254 ++++----
.../new_chat/tools/confluence/update_page.py | 296 +++++----
.../new_chat/tools/connected_accounts.py | 80 ++-
.../new_chat/tools/discord/list_channels.py | 108 ++--
.../new_chat/tools/discord/read_messages.py | 112 ++--
.../new_chat/tools/discord/send_message.py | 129 ++--
.../new_chat/tools/dropbox/create_file.py | 336 +++++-----
.../new_chat/tools/dropbox/trash_file.py | 344 +++++-----
.../new_chat/tools/gmail/create_draft.py | 493 +++++++-------
.../agents/new_chat/tools/gmail/read_email.py | 180 +++---
.../new_chat/tools/gmail/search_emails.py | 166 ++---
.../agents/new_chat/tools/gmail/send_email.py | 497 +++++++-------
.../new_chat/tools/gmail/trash_email.py | 477 +++++++-------
.../new_chat/tools/gmail/update_draft.py | 611 +++++++++---------
.../tools/google_calendar/create_event.py | 542 ++++++++--------
.../tools/google_calendar/delete_event.py | 472 +++++++-------
.../tools/google_calendar/search_events.py | 160 +++--
.../tools/google_calendar/update_event.py | 581 +++++++++--------
.../tools/google_drive/create_file.py | 447 +++++++------
.../new_chat/tools/google_drive/trash_file.py | 399 ++++++------
.../new_chat/tools/jira/create_issue.py | 310 +++++----
.../new_chat/tools/jira/delete_issue.py | 249 +++----
.../new_chat/tools/jira/update_issue.py | 315 +++++----
.../new_chat/tools/linear/create_issue.py | 310 ++++-----
.../new_chat/tools/linear/delete_issue.py | 279 ++++----
.../new_chat/tools/linear/update_issue.py | 339 +++++-----
.../new_chat/tools/luma/create_event.py | 159 +++--
.../agents/new_chat/tools/luma/list_events.py | 142 ++--
.../agents/new_chat/tools/luma/read_event.py | 116 ++--
.../new_chat/tools/notion/create_page.py | 284 ++++----
.../new_chat/tools/notion/delete_page.py | 319 ++++-----
.../new_chat/tools/notion/update_page.py | 297 +++++----
.../new_chat/tools/onedrive/create_file.py | 320 ++++-----
.../new_chat/tools/onedrive/trash_file.py | 350 +++++-----
.../app/agents/new_chat/tools/registry.py | 88 ++-
.../new_chat/tools/search_surfsense_docs.py | 22 +-
.../new_chat/tools/teams/list_channels.py | 120 ++--
.../new_chat/tools/teams/read_messages.py | 120 ++--
.../new_chat/tools/teams/send_message.py | 133 ++--
.../agents/new_chat/tools/update_memory.py | 112 +++-
surfsense_backend/app/app.py | 141 ++++
.../app/services/connector_service.py | 190 +++++-
.../app/tasks/chat/stream_new_chat.py | 322 ++++++---
.../unit/agents/new_chat/test_agent_cache.py | 268 ++++++++
.../agents/new_chat/test_feature_flags.py | 7 +
.../agents/new_chat/test_flatten_system.py | 344 ++++++++++
.../agents/new_chat/test_prompt_caching.py | 28 +-
.../unit/middleware/test_knowledge_search.py | 187 ++++++
.../unit/test_stream_new_chat_contract.py | 60 ++
.../components/pricing/pricing-section.tsx | 4 +-
60 files changed, 8477 insertions(+), 5381 deletions(-)
create mode 100644 surfsense_backend/app/agents/new_chat/agent_cache.py
create mode 100644 surfsense_backend/app/agents/new_chat/middleware/flatten_system.py
create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py
create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py
diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example
index 1b1478ae6..86c1b326f 100644
--- a/surfsense_backend/.env.example
+++ b/surfsense_backend/.env.example
@@ -324,3 +324,30 @@ LANGSMITH_PROJECT=surfsense
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
# Comma-separated allowlist of plugin entry-point names
# SURFSENSE_ALLOWED_PLUGINS=year_substituter
+
+# -----------------------------------------------------------------------------
+# Compiled-agent cache (Phase 1 + 2 perf optimization, default ON)
+# -----------------------------------------------------------------------------
+# When ON, the per-turn LangGraph + middleware compile result (~3-5s of CPU
+# on a cold turn) is reused across subsequent turns on the same thread,
+# collapsing it to a microsecond hash lookup. All connector tools acquire
+# their own short-lived DB session per call (Phase 2 refactor) so a cached
+# closure is safe to share across requests. Flip OFF only as a last-resort
+# rollback if you suspect cache-related staleness.
+# SURFSENSE_ENABLE_AGENT_CACHE=true
+
+# Cache capacity (max number of compiled-agent entries kept in memory)
+# and TTL per entry (seconds). Working set is typically one entry per
+# active thread on this replica; tune up for very large deployments.
+# SURFSENSE_AGENT_CACHE_MAXSIZE=256
+# SURFSENSE_AGENT_CACHE_TTL_SECONDS=1800
+
+# -----------------------------------------------------------------------------
+# Connector discovery TTL cache (Phase 1.4 perf optimization)
+# -----------------------------------------------------------------------------
+# Caches the per-search-space "available connectors" + "available document
+# types" lookups that ``create_surfsense_deep_agent`` hits on every turn.
+# ORM event listeners auto-invalidate on connector / document inserts,
+# updates and deletes — the TTL only bounds staleness for bulk-import
+# paths that bypass the ORM. Set to 0 to disable the cache.
+# SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30
diff --git a/surfsense_backend/app/agents/new_chat/agent_cache.py b/surfsense_backend/app/agents/new_chat/agent_cache.py
new file mode 100644
index 000000000..fa8e6fb72
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/agent_cache.py
@@ -0,0 +1,357 @@
+"""TTL-LRU cache for compiled SurfSense deep agents.
+
+Why this exists
+---------------
+
+``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat
+turn:
+
+1. Discover connectors & document types from Postgres (~50-200ms)
+2. Build the tool list (built-in + MCP) (~200ms-1.7s)
+3. Compose the system prompt
+4. Construct ~15 middleware instances (CPU)
+5. Eagerly compile the general-purpose subagent
+ (``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously,
+ which builds a second LangGraph + Pydantic schemas — ~1.5-2s of pure
+ CPU work)
+6. Compile the outer LangGraph
+
+For a single thread, all six steps produce the SAME object on every turn
+unless the user has changed their LLM config, toggled a feature flag,
+added a connector, etc. The right answer is to compile ONCE per
+"agent shape" and reuse the resulting :class:`CompiledStateGraph` for
+every subsequent turn on the same thread.
+
+Why a per-thread key (not a global pool)
+----------------------------------------
+
+Most middleware in the SurfSense stack captures per-thread state in
+``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``,
+``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse
+would silently leak state across users and threads. Keying the cache on
+``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated
+turns on the same thread without changing any middleware's behavior.
+
+Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema`
+(read via ``runtime.context``) so the cache can collapse to a single
+``(llm_config_id, search_space_id, ...)`` key shared across threads. Until
+then, per-thread keying is the only safe option.
+
+Cache shape
+-----------
+
+* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s, 30
+ minutes — matches a typical chat session). ``maxsize`` (default 256)
+ caps memory; LRU evicts least-recently-used on overflow.
+* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent
+ cold misses on the same key wait for the first build instead of
+ building N times.
+* Process-local: this is an in-memory cache. Multi-replica deployments
+ pay the build cost once per replica per key. That's fine; the working
+ set per replica is small (one entry per active thread on that replica).
+
+Telemetry
+---------
+
+Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``:
+
+ * ``hit`` — cache hit, microseconds-fast
+ * ``miss`` — first build for this key, includes build duration
+ * ``stale`` — entry was found but expired; rebuilt
+ * ``evict`` — LRU eviction (size-limited)
+ * ``size`` — current cache occupancy at lookup time
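+
+Illustrative lines (values made up)::
+
+    [agent_cache] miss key=0b8f3c44aa12de90... build=4.213s size=1
+    [agent_cache] hit key=0b8f3c44aa12de90... age=37.4s size=1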
+"""
+
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import logging
+import os
+import time
+from collections import OrderedDict
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass
+from typing import Any
+
+from app.utils.perf import get_perf_logger
+
+logger = logging.getLogger(__name__)
+_perf_log = get_perf_logger()
+
+
+# ---------------------------------------------------------------------------
+# Public API: signature helpers (cache key components)
+# ---------------------------------------------------------------------------
+
+
+def stable_hash(*parts: Any) -> str:
+ """Compute a deterministic SHA1 of the str repr of ``parts``.
+
+ Used for cache key components that need a fixed-width representation
+ (system prompt, tool list, etc.). SHA1 is fine here — this is not a
+ security boundary, just a content fingerprint.
+ """
+ h = hashlib.sha1(usedforsecurity=False)
+ for p in parts:
+ h.update(repr(p).encode("utf-8", errors="replace"))
+ h.update(b"\x1f") # ASCII unit separator between parts
+ return h.hexdigest()
+
+
+def tools_signature(
+ tools: list[Any] | tuple[Any, ...],
+ *,
+ available_connectors: list[str] | None,
+ available_document_types: list[str] | None,
+) -> str:
+ """Hash the bound-tool surface for cache-key purposes.
+
+ The signature changes whenever:
+
+ * A tool is added or removed from the bound list (built-in toggles,
+ MCP tools loaded for the user changes, gating rules flip, etc.).
+ * The available connectors / document types for the search space
+ change (new connector added, last connector removed, new document
+ type indexed). Because :func:`get_connector_gated_tools` derives
+ ``modified_disabled_tools`` from ``available_connectors``, the
+ tool surface is technically already covered — but we hash the
+ connector list separately so an empty-list "no tools changed"
+ situation still rotates the key when, say, the user re-adds a
+ connector that gates a tool we were already not exposing.
+
+ Stays stable across:
+
+ * Process restarts (tool names + descriptions are static).
+ * Different replicas (everyone gets the same hash for the same
+ inputs).
+ """
+ tool_descriptors = sorted(
+ (getattr(t, "name", repr(t)), getattr(t, "description", "")) for t in tools
+ )
+ connectors = sorted(available_connectors or [])
+ doc_types = sorted(available_document_types or [])
+ return stable_hash(tool_descriptors, connectors, doc_types)
+
+
+def flags_signature(flags: Any) -> str:
+ """Hash the resolved :class:`AgentFeatureFlags` dataclass.
+
+ Frozen dataclasses are deterministically reprable, so a SHA1 of their
+ repr is a stable fingerprint. Restart safe (flags are read once at
+ process boot).
+ """
+ return stable_hash(repr(flags))
+
+
+def system_prompt_hash(system_prompt: str) -> str:
+ """Hash a system prompt string. Cheap, ~30µs for typical prompts."""
+ return hashlib.sha1(
+ system_prompt.encode("utf-8", errors="replace"),
+ usedforsecurity=False,
+ ).hexdigest()
+
+
+# ---------------------------------------------------------------------------
+# Cache implementation
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _Entry:
+ value: Any
+ created_at: float
+ last_used_at: float
+
+
+class _AgentCache:
+ """In-process TTL-LRU cache with per-key in-flight de-duplication.
+
+ NOT THREAD-SAFE in the multithreading sense — designed for a single
+ asyncio event loop. Uvicorn runs one event loop per worker process,
+ so this is fine; multi-worker deployments simply each maintain their
+ own cache.
+ """
+
+ def __init__(self, *, maxsize: int, ttl_seconds: float) -> None:
+ self._maxsize = maxsize
+ self._ttl = ttl_seconds
+ self._entries: OrderedDict[str, _Entry] = OrderedDict()
+ # One lock per key — guards "build" so concurrent cold misses on
+ # the same key wait for the first build instead of all racing.
+ self._locks: dict[str, asyncio.Lock] = {}
+
+ def _now(self) -> float:
+ return time.monotonic()
+
+ def _is_fresh(self, entry: _Entry) -> bool:
+ return (self._now() - entry.created_at) < self._ttl
+
+ def _evict_if_full(self) -> None:
+ while len(self._entries) >= self._maxsize:
+ evicted_key, _ = self._entries.popitem(last=False)
+ self._locks.pop(evicted_key, None)
+ _perf_log.info(
+ "[agent_cache] evict key=%s reason=lru size=%d",
+ _short(evicted_key),
+ len(self._entries),
+ )
+
+ def _touch(self, key: str, entry: _Entry) -> None:
+ entry.last_used_at = self._now()
+ self._entries.move_to_end(key, last=True)
+
+ async def get_or_build(
+ self,
+ key: str,
+ *,
+ builder: Callable[[], Awaitable[Any]],
+ ) -> Any:
+ """Return the cached value for ``key`` or call ``builder()`` to make it.
+
+ ``builder`` MUST be idempotent — concurrent cold misses on the
+ same key collapse to a single ``builder()`` call (the others
+ wait on the in-flight lock and observe the populated entry on
+ wake).
+ """
+ # Fast path: hot hit.
+ entry = self._entries.get(key)
+ if entry is not None and self._is_fresh(entry):
+ self._touch(key, entry)
+ _perf_log.info(
+ "[agent_cache] hit key=%s age=%.1fs size=%d",
+ _short(key),
+ self._now() - entry.created_at,
+ len(self._entries),
+ )
+ return entry.value
+
+ # Stale entry — drop it; rebuild below.
+ if entry is not None and not self._is_fresh(entry):
+ _perf_log.info(
+ "[agent_cache] stale key=%s age=%.1fs ttl=%.0fs",
+ _short(key),
+ self._now() - entry.created_at,
+ self._ttl,
+ )
+ self._entries.pop(key, None)
+
+ # Slow path: serialize concurrent misses for the same key.
+ lock = self._locks.setdefault(key, asyncio.Lock())
+ async with lock:
+ # Double-check after acquiring the lock — another waiter may
+ # have populated the entry while we slept.
+ entry = self._entries.get(key)
+ if entry is not None and self._is_fresh(entry):
+ self._touch(key, entry)
+ _perf_log.info(
+ "[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true",
+ _short(key),
+ self._now() - entry.created_at,
+ len(self._entries),
+ )
+ return entry.value
+
+ t0 = time.perf_counter()
+ try:
+ value = await builder()
+ except BaseException:
+ # Don't cache failed builds; let the next caller retry.
+ _perf_log.warning(
+ "[agent_cache] build_failed key=%s elapsed=%.3fs",
+ _short(key),
+ time.perf_counter() - t0,
+ )
+ raise
+ elapsed = time.perf_counter() - t0
+
+ # Insert + evict.
+ self._evict_if_full()
+ now = self._now()
+ self._entries[key] = _Entry(value=value, created_at=now, last_used_at=now)
+ self._entries.move_to_end(key, last=True)
+ _perf_log.info(
+ "[agent_cache] miss key=%s build=%.3fs size=%d",
+ _short(key),
+ elapsed,
+ len(self._entries),
+ )
+ return value
+
+ def invalidate(self, key: str) -> bool:
+ """Drop a single entry; return True if anything was removed."""
+ removed = self._entries.pop(key, None) is not None
+ self._locks.pop(key, None)
+ if removed:
+ _perf_log.info(
+ "[agent_cache] invalidate key=%s size=%d",
+ _short(key),
+ len(self._entries),
+ )
+ return removed
+
+ def invalidate_prefix(self, prefix: str) -> int:
+ """Drop every entry whose key starts with ``prefix``. Returns count."""
+ keys = [k for k in self._entries if k.startswith(prefix)]
+ for k in keys:
+ self._entries.pop(k, None)
+ self._locks.pop(k, None)
+ if keys:
+ _perf_log.info(
+ "[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d",
+ _short(prefix),
+ len(keys),
+ len(self._entries),
+ )
+ return len(keys)
+
+ def clear(self) -> None:
+ n = len(self._entries)
+ self._entries.clear()
+ self._locks.clear()
+ if n:
+ _perf_log.info("[agent_cache] clear removed=%d", n)
+
+ def stats(self) -> dict[str, Any]:
+ return {
+ "size": len(self._entries),
+ "maxsize": self._maxsize,
+ "ttl_seconds": self._ttl,
+ }
+
+
+def _short(key: str, n: int = 16) -> str:
+ """Truncate keys for log lines so they don't blow up log volume."""
+ return key if len(key) <= n else f"{key[:n]}..."
+
+
+# ---------------------------------------------------------------------------
+# Module-level singleton
+# ---------------------------------------------------------------------------
+
+_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256"))
+_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800"))
+
+_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL)
+
+
+def get_cache() -> _AgentCache:
+ """Return the process-wide compiled-agent cache singleton."""
+ return _cache
+
+
+def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache:
+ """Replace the singleton with a fresh cache. Tests only."""
+ global _cache
+ _cache = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds)
+ return _cache
+
+
+__all__ = [
+ "flags_signature",
+ "get_cache",
+ "reload_for_tests",
+ "stable_hash",
+ "system_prompt_hash",
+ "tools_signature",
+]
diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py
index c0e9a3b96..36739adae 100644
--- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py
+++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py
@@ -40,6 +40,13 @@ from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession
+from app.agents.new_chat.agent_cache import (
+ flags_signature,
+ get_cache,
+ stable_hash,
+ system_prompt_hash,
+ tools_signature,
+)
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.agents.new_chat.filesystem_backends import build_backend_resolver
@@ -53,6 +60,7 @@ from app.agents.new_chat.middleware import (
DedupHITLToolCallsMiddleware,
DoomLoopMiddleware,
FileIntentMiddleware,
+ FlattenSystemMessageMiddleware,
KnowledgeBasePersistenceMiddleware,
KnowledgePriorityMiddleware,
KnowledgeTreeMiddleware,
@@ -330,23 +338,39 @@ async def create_surfsense_deep_agent(
else None,
)
- # Discover available connectors and document types for this search space
+ # Discover available connectors and document types for this search space.
+ #
+ # NOTE: These two calls cannot be parallelized via ``asyncio.gather``.
+ # ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``);
+ # SQLAlchemy explicitly forbids concurrent operations on the same session
+ # ("This session is provisioning a new connection; concurrent operations
+ # are not permitted on the same session"). The Phase 1.4 in-process TTL
+ # cache in ``connector_service`` already collapses the warm path to a
+ # near-zero pair of dict lookups, so sequential awaits cost nothing in
+ # the common case while remaining correct on cold cache misses.
available_connectors: list[str] | None = None
available_document_types: list[str] | None = None
_t0 = time.perf_counter()
try:
- connector_types = await connector_service.get_available_connectors(
- search_space_id
- )
- if connector_types:
- available_connectors = _map_connectors_to_searchable_types(connector_types)
+ try:
+ connector_types_result = await connector_service.get_available_connectors(
+ search_space_id
+ )
+ if connector_types_result:
+ available_connectors = _map_connectors_to_searchable_types(
+ connector_types_result
+ )
+ except Exception as e:
+ logging.warning("Failed to discover available connectors: %s", e)
- available_document_types = await connector_service.get_available_document_types(
- search_space_id
- )
-
- except Exception as e:
+ try:
+ available_document_types = (
+ await connector_service.get_available_document_types(search_space_id)
+ )
+ except Exception as e:
+ logging.warning("Failed to discover available document types: %s", e)
+ except Exception as e: # pragma: no cover - defensive outer guard
logging.warning(f"Failed to discover available connectors/document types: {e}")
_perf_log.info(
"[create_agent] Connector/doc-type discovery in %.3fs",
@@ -469,29 +493,77 @@ async def create_surfsense_deep_agent(
# entire middleware build + main-graph compile into a single
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
# event loop stays responsive.
+ #
+ # PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed
+ # on every per-request value that any middleware in the stack closes
+ # over in ``__init__`` — drop one and you risk leaking state across
+ # threads. Hits collapse this whole block to a microsecond lookup;
+ # misses pay the original CPU cost AND populate the cache.
+ config_id = agent_config.config_id if agent_config is not None else None
+
+ async def _build_agent() -> Any:
+ return await asyncio.to_thread(
+ _build_compiled_agent_blocking,
+ llm=llm,
+ tools=tools,
+ final_system_prompt=final_system_prompt,
+ backend_resolver=backend_resolver,
+ filesystem_mode=filesystem_selection.mode,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ thread_id=thread_id,
+ visibility=visibility,
+ anon_session_id=anon_session_id,
+ available_connectors=available_connectors,
+ available_document_types=available_document_types,
+ # ``mentioned_document_ids`` is consumed by
+ # ``KnowledgePriorityMiddleware`` per turn via
+ # ``runtime.context`` (Phase 1.5). We still pass the
+ # caller-provided list here for the legacy fallback path
+ # (cache disabled / context not propagated) — the middleware
+ # drains its own copy after the first read so a cached graph
+ # never replays stale mentions.
+ mentioned_document_ids=mentioned_document_ids,
+ max_input_tokens=_max_input_tokens,
+ flags=_flags,
+ checkpointer=checkpointer,
+ )
+
_t0 = time.perf_counter()
- agent = await asyncio.to_thread(
- _build_compiled_agent_blocking,
- llm=llm,
- tools=tools,
- final_system_prompt=final_system_prompt,
- backend_resolver=backend_resolver,
- filesystem_mode=filesystem_selection.mode,
- search_space_id=search_space_id,
- user_id=user_id,
- thread_id=thread_id,
- visibility=visibility,
- anon_session_id=anon_session_id,
- available_connectors=available_connectors,
- available_document_types=available_document_types,
- mentioned_document_ids=mentioned_document_ids,
- max_input_tokens=_max_input_tokens,
- flags=_flags,
- checkpointer=checkpointer,
- )
+ if _flags.enable_agent_cache and not _flags.disable_new_agent_stack:
+        # Cache key components. Only the resulting hash is stored, so the
+        # ordering below is purely for readability; what matters is that
+        # every component rotates on a real shape change AND stays stable
+        # across identical invocations.
+ cache_key = stable_hash(
+ "v1", # schema version of the key — bump if components change
+ config_id,
+ thread_id,
+ user_id,
+ search_space_id,
+ visibility,
+ filesystem_selection.mode,
+ anon_session_id,
+ tools_signature(
+ tools,
+ available_connectors=available_connectors,
+ available_document_types=available_document_types,
+ ),
+ flags_signature(_flags),
+ system_prompt_hash(final_system_prompt),
+ _max_input_tokens,
+ # ``mentioned_document_ids`` deliberately omitted — middleware
+ # reads it from ``runtime.context`` (Phase 1.5).
+ )
+ agent = await get_cache().get_or_build(cache_key, builder=_build_agent)
+ else:
+ agent = await _build_agent()
_perf_log.info(
- "[create_agent] Middleware stack + graph compiled in %.3fs",
+ "[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)",
time.perf_counter() - _t0,
+ "on"
+ if _flags.enable_agent_cache and not _flags.disable_new_agent_stack
+ else "off",
)
_perf_log.info(
@@ -1038,6 +1110,14 @@ def _build_compiled_agent_blocking(
noop_mw,
retry_mw,
fallback_mw,
+ # Coalesce a multi-text-block system message into one block
+ # immediately before the model call. Sits innermost on the
+ # system-message-mutation chain so it observes every appender
+ # (todo / filesystem / skills / subagents …) and prevents
+ # OpenRouter→Anthropic from redistributing ``cache_control``
+ # across N blocks and tripping Anthropic's 4-breakpoint cap.
+ # See ``middleware/flatten_system.py`` for full rationale.
+ FlattenSystemMessageMiddleware(),
# Tool-call repair must run after model emits but before
# permission / dedup / doom-loop interpret the calls.
repair_mw,
diff --git a/surfsense_backend/app/agents/new_chat/context.py b/surfsense_backend/app/agents/new_chat/context.py
index c1fe45aaa..d720b524b 100644
--- a/surfsense_backend/app/agents/new_chat/context.py
+++ b/surfsense_backend/app/agents/new_chat/context.py
@@ -1,10 +1,25 @@
"""
Context schema definitions for SurfSense agents.
-This module defines the custom state schema used by the SurfSense deep agent.
+This module defines the per-invocation context object passed to the SurfSense
+deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6).
+
+The agent's compiled graph is the same across invocations (and cached by
+``agent_cache``), so anything that varies per turn — the user mentions a
+specific document, the front-end issues a unique ``request_id``, etc. —
+MUST live on this context object instead of being captured into a
+middleware ``__init__`` closure. Middlewares and tools alike read those
+fields back via ``runtime.context``.
+
+This object is read by ``KnowledgePriorityMiddleware`` (for
+``mentioned_document_ids``) and by any future middleware that needs
+per-request state without invalidating the compiled-agent cache.
"""
-from typing import NotRequired, TypedDict
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TypedDict
class FileOperationContractState(TypedDict):
@@ -15,25 +30,35 @@ class FileOperationContractState(TypedDict):
turn_id: str
-class SurfSenseContextSchema(TypedDict):
+@dataclass
+class SurfSenseContextSchema:
"""
- Custom state schema for the SurfSense deep agent.
+ Per-invocation context for the SurfSense deep agent.
- This extends the default agent state with custom fields.
- The default state already includes:
- - messages: Conversation history
- - todos: Task list from TodoListMiddleware
- - files: Virtual filesystem from FilesystemMiddleware
+ Defaults are chosen so the dataclass can be safely default-constructed
+ (LangGraph's ``Runtime.context`` itself defaults to ``None`` if no
+ context is supplied — see ``langgraph.runtime.Runtime``). All fields
+ are optional; consumers must None-check before reading.
- We're adding fields needed for knowledge base search:
- - search_space_id: The user's search space ID
- - db_session: Database session (injected at runtime)
- - connector_service: Connector service instance (injected at runtime)
+ Phase 1.5 fields:
+ search_space_id: Search space the request is scoped to.
+ mentioned_document_ids: KB documents the user @-mentioned this turn.
+ Read by ``KnowledgePriorityMiddleware`` to seed its priority
+ list. Stays out of the compiled-agent cache key — that's the
+ whole point of putting it here.
+ file_operation_contract: One-shot file operation contract emitted
+ by ``FileIntentMiddleware`` for the upcoming turn.
+ turn_id / request_id: Correlation IDs surfaced by the streaming
+ task; populated for telemetry.
+
+ Phase 2 will extend with: thread_id, user_id, visibility,
+ filesystem_mode, anon_session_id, available_connectors,
+ available_document_types, created_by_id (everything currently captured
+ by middleware ``__init__`` closures).
"""
- search_space_id: int
- file_operation_contract: NotRequired[FileOperationContractState]
- turn_id: NotRequired[str]
- request_id: NotRequired[str]
- # These are runtime-injected and won't be serialized
- # db_session and connector_service are passed when invoking the agent
+ search_space_id: int | None = None
+ mentioned_document_ids: list[int] = field(default_factory=list)
+ file_operation_contract: FileOperationContractState | None = None
+ turn_id: str | None = None
+ request_id: str | None = None
diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py
index 5007d89a5..1f5a08ec6 100644
--- a/surfsense_backend/app/agents/new_chat/feature_flags.py
+++ b/surfsense_backend/app/agents/new_chat/feature_flags.py
@@ -103,6 +103,41 @@ class AgentFeatureFlags:
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
enable_otel: bool = False
+ # Performance — compiled-agent cache (Phase 1 + Phase 2).
+ # When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled
+ # graph if the cache key matches (LLM config + thread + tool surface +
+ # flags + system prompt + filesystem mode). Cuts per-turn agent-build
+ # wall clock from ~4-5s to <50µs on cache hits.
+ #
+ # SAFETY (Phase 2 unblocked this default-on):
+ # All connector mutation tools (``tools/notion``, ``tools/gmail``,
+ # ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
+ # ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
+ # ``tools/teams``, ``tools/luma``, ``connected_accounts``,
+ # ``update_memory``, ``search_surfsense_docs``) now acquire fresh
+ # short-lived ``AsyncSession`` instances per call via
+ # :data:`async_session_maker`. The factory still accepts ``db_session``
+ # for registry compatibility but ``del``'s it immediately — see any
+ # of those files' factory docstrings for the rationale. The ``llm``
+ # closure is per-(provider, model, config_id) which is already in
+ # the cache key, so the LLM is safe to share across cached hits of
+ # the same key. The KB priority middleware reads
+ # ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
+ # not its constructor closure, so the same compiled agent serves
+ # turns with different mention lists correctly.
+ #
+ # Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
+ # environment if a regression surfaces. The path is exercised by
+    # the ``tests/unit/agents/new_chat/test_agent_cache.py`` suite.
+ enable_agent_cache: bool = True
+ # Phase 1 (deferred — measure first): pre-build & share the
+ # general-purpose subagent ``CompiledSubAgent`` across cold-cache
+ # misses. Only helps when the outer cache MISSES (cache hits already
+ # reuse the entire SubAgentMiddleware-compiled graph). Off by default
+ # until we have data showing cold misses are frequent enough to
+ # justify the extra global state.
+ enable_agent_cache_share_gp_subagent: bool = False
+
@classmethod
def from_env(cls) -> AgentFeatureFlags:
"""Read flags from environment.
@@ -137,6 +172,8 @@ class AgentFeatureFlags:
enable_stream_parity_v2=False,
enable_plugin_loader=False,
enable_otel=False,
+ enable_agent_cache=False,
+ enable_agent_cache_share_gp_subagent=False,
)
return cls(
@@ -179,6 +216,11 @@ class AgentFeatureFlags:
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
# Observability
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
+ # Performance
+ enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True),
+ enable_agent_cache_share_gp_subagent=_env_bool(
+ "SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False
+ ),
)
def any_new_middleware_enabled(self) -> bool:
diff --git a/surfsense_backend/app/agents/new_chat/middleware/__init__.py b/surfsense_backend/app/agents/new_chat/middleware/__init__.py
index 094c102f8..6742bd8de 100644
--- a/surfsense_backend/app/agents/new_chat/middleware/__init__.py
+++ b/surfsense_backend/app/agents/new_chat/middleware/__init__.py
@@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import (
from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware,
)
+from app.agents.new_chat.middleware.flatten_system import (
+ FlattenSystemMessageMiddleware,
+)
from app.agents.new_chat.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware,
commit_staged_filesystem_state,
@@ -61,6 +64,7 @@ __all__ = [
"DedupHITLToolCallsMiddleware",
"DoomLoopMiddleware",
"FileIntentMiddleware",
+ "FlattenSystemMessageMiddleware",
"KnowledgeBasePersistenceMiddleware",
"KnowledgeBaseSearchMiddleware",
"KnowledgePriorityMiddleware",
diff --git a/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py b/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py
new file mode 100644
index 000000000..29cd57aa0
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py
@@ -0,0 +1,233 @@
+r"""Coalesce multi-block system messages into a single text block.
+
+Several middlewares in our deepagent stack each call
+``append_to_system_message`` on the way down to the model
+(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
+``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the
+request reaches the LLM, the system message has 5+ separate text blocks.
+
+Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
+request**, and we configure 2 injection points
+(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
+the prepended ``request.system_message``, this middleware is the
+defensive partner: it guarantees that "the system block" is *one*
+content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
+OpenRouter→Anthropic transformer can never multiply our budget into
+several breakpoints by spreading ``cache_control`` across multiple
+text blocks of a multi-block system content.
+
+Without flattening we used to see::
+
+ OpenrouterException - {"error":{"message":"Provider returned error",
+ "code":400,"metadata":{"raw":"...A maximum of 4 blocks with
+ cache_control may be provided. Found 5."}}}
+
+(Same error class documented in
+https://github.com/BerriAI/litellm/issues/15696 and
+https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix
+in PR #15395 covers the litellm transformer but does not protect us
+when the OpenRouter SaaS itself does the redistribution.)
+
+A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching
+the first injection point from ``role: system`` to ``index: 0``)
+neutralises the *primary* cause of the same 400 — multiple
+``SystemMessage``\ s injected by ``before_agent`` middlewares
+(priority/tree/memory/file-intent/anonymous-doc) accumulating across
+turns, each tagged with ``cache_control`` by the ``role: system``
+matcher. This middleware remains useful as defence-in-depth against
+the multi-block redistribution path.
+
+Placement: innermost on the system-message-mutation chain, after every
+appender (``todo``/``filesystem``/``skills``/``subagents``) and after
+summarization, but before ``noop``/``retry``/``fallback`` so each retry
+attempt sees a flattened payload. See ``chat_deepagent.py``.
+
+Idempotent: a string-content system message is left untouched. A list
+that contains anything other than plain text blocks (e.g. an image) is
+also left untouched — those are rare on system messages and we'd lose
+the non-text payload by joining.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Awaitable, Callable
+from typing import Any
+
+from langchain.agents.middleware.types import (
+ AgentMiddleware,
+ AgentState,
+ ContextT,
+ ModelRequest,
+ ModelResponse,
+ ResponseT,
+)
+from langchain_core.messages import SystemMessage
+
+logger = logging.getLogger(__name__)
+
+
+def _flatten_text_blocks(content: list[Any]) -> str | None:
+ """Return joined text if every block is a plain ``{"type": "text"}``.
+
+ Returns ``None`` when the list contains anything that isn't a text
+    block we can safely concatenate (image, audio, file, or any block
+    whose ``text`` payload is missing or not a plain string). The caller
+ leaves the original content untouched in that case rather than
+ silently dropping payload.
+
+ ``cache_control`` on individual blocks is intentionally discarded —
+ the whole point of flattening is to let LiteLLM's
+ ``cache_control_injection_points`` re-place a single breakpoint on
+ the resulting one-block system content.
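+
+    Illustrative shapes (made-up block texts)::
+
+        blocks = [
+            {"type": "text", "text": "base prompt",
+             "cache_control": {"type": "ephemeral"}},
+            {"type": "text", "text": "## Todos"},
+        ]
+        _flatten_text_blocks(blocks)    # -> "base prompt\\n\\n## Todos"
+        _flatten_text_blocks([{"type": "image", "data": "..."}])  # -> None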
+ """
+ chunks: list[str] = []
+ for block in content:
+ if isinstance(block, str):
+ chunks.append(block)
+ continue
+ if not isinstance(block, dict):
+ return None
+ if block.get("type") != "text":
+ return None
+ text = block.get("text")
+ if not isinstance(text, str):
+ return None
+ chunks.append(text)
+ return "\n\n".join(chunks)
+
+
+def _flattened_request(
+ request: ModelRequest[ContextT],
+) -> ModelRequest[ContextT] | None:
+ """Return a request with system_message flattened, or ``None`` for no-op."""
+ sys_msg = request.system_message
+ if sys_msg is None:
+ return None
+ content = sys_msg.content
+ if not isinstance(content, list) or len(content) <= 1:
+ return None
+
+ flattened = _flatten_text_blocks(content)
+ if flattened is None:
+ return None
+
+ new_sys = SystemMessage(
+ content=flattened,
+ additional_kwargs=dict(sys_msg.additional_kwargs),
+ response_metadata=dict(sys_msg.response_metadata),
+ )
+ if sys_msg.id is not None:
+ new_sys.id = sys_msg.id
+ return request.override(system_message=new_sys)
+
+
+def _diagnostic_summary(request: ModelRequest[Any]) -> str:
+ """One-line dump of cache_control-relevant request shape.
+
+ Temporary diagnostic to prove where the ``Found N`` cache_control
+ breakpoints are coming from when Anthropic 400s. Removed once the
+    breakpoints are coming from when Anthropic 400s. To be removed once the
+ """
+ sys_msg = request.system_message
+ if sys_msg is None:
+ sys_shape = "none"
+ elif isinstance(sys_msg.content, str):
+ sys_shape = f"str(len={len(sys_msg.content)})"
+ elif isinstance(sys_msg.content, list):
+ sys_shape = f"list(blocks={len(sys_msg.content)})"
+ else:
+ sys_shape = f"other({type(sys_msg.content).__name__})"
+
+ role_hist: list[str] = []
+ multi_block_msgs = 0
+ msgs_with_cc = 0
+ sys_msgs_in_history = 0
+ for m in request.messages:
+ mtype = getattr(m, "type", type(m).__name__)
+ role_hist.append(mtype)
+ if isinstance(m, SystemMessage):
+ sys_msgs_in_history += 1
+ c = getattr(m, "content", None)
+ if isinstance(c, list):
+ multi_block_msgs += 1
+ for blk in c:
+ if isinstance(blk, dict) and "cache_control" in blk:
+ msgs_with_cc += 1
+ break
+        if "cache_control" in (getattr(m, "additional_kwargs", None) or {}):
+ msgs_with_cc += 1
+
+ tools = request.tools or []
+ tools_with_cc = 0
+ for t in tools:
+ if isinstance(t, dict) and (
+ "cache_control" in t or "cache_control" in t.get("function", {})
+ ):
+ tools_with_cc += 1
+
+ return (
+ f"sys={sys_shape} msgs={len(request.messages)} "
+ f"sys_msgs_in_history={sys_msgs_in_history} "
+ f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
+ f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
+ f"roles={role_hist[-8:]}"
+ )
+
+
+class FlattenSystemMessageMiddleware(
+ AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
+):
+ """Collapse a multi-text-block system message to a single string.
+
+ Sits innermost on the system-message-mutation chain so it observes
+ every middleware's contribution. Has no other side effect — the
+ body of every block is preserved, just joined with ``"\\n\\n"``.
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.tools = []
+
+ def wrap_model_call( # type: ignore[override]
+ self,
+ request: ModelRequest[ContextT],
+ handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
+ ) -> Any:
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
+ flattened = _flattened_request(request)
+ if flattened is not None:
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(
+ "[flatten_system] collapsed %d system blocks to one",
+ len(request.system_message.content), # type: ignore[arg-type, union-attr]
+ )
+ return handler(flattened)
+ return handler(request)
+
+ async def awrap_model_call( # type: ignore[override]
+ self,
+ request: ModelRequest[ContextT],
+ handler: Callable[
+ [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
+ ],
+ ) -> Any:
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
+ flattened = _flattened_request(request)
+ if flattened is not None:
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(
+ "[flatten_system] collapsed %d system blocks to one",
+ len(request.system_message.content), # type: ignore[arg-type, union-attr]
+ )
+ return await handler(flattened)
+ return await handler(request)
+
+
+__all__ = [
+ "FlattenSystemMessageMiddleware",
+ "_flatten_text_blocks",
+ "_flattened_request",
+]
diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py
index 0820e8c3e..ee5c1d182 100644
--- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py
+++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py
@@ -732,7 +732,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
- del runtime
if self.filesystem_mode != FilesystemMode.CLOUD:
return None
@@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
if anon_doc:
return self._anon_priority(state, anon_doc)
- return await self._authenticated_priority(state, messages, user_text)
+ return await self._authenticated_priority(state, messages, user_text, runtime)
def _anon_priority(
self,
@@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState,
messages: Sequence[BaseMessage],
user_text: str,
+ runtime: Runtime[Any] | None = None,
) -> dict[str, Any]:
t0 = asyncio.get_event_loop().time()
(
@@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
user_text=user_text,
)
+ # Per-turn ``mentioned_document_ids`` flow:
+ # 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
+ # streaming task supplies a fresh :class:`SurfSenseContextSchema`
+ # on every ``astream_events`` call, so this list is naturally
+ # scoped to the current turn. Allows cross-turn graph reuse via
+ # ``agent_cache``.
+ # 2. Legacy fallback (cache disabled / context not propagated): the
+ # constructor-injected ``self.mentioned_document_ids`` list. We
+ # drain it after the first read so a cached graph (no Phase 1.5
+ # wiring) doesn't keep replaying the same mentions on every
+ # turn.
+ #
+ # CRITICAL: distinguish "context absent" (legacy caller, no field at
+ # all) from "context provided but empty" (turn with no mentions).
+ # ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
+ # Python, so a naive ``if ctx_mentions:`` would fall through to the
+ # legacy closure on every no-mention follow-up turn — replaying the
+ # mentions baked in by turn 1's cache-miss build. Always drain the
+ # closure once the runtime path has fired so a cached middleware
+ # instance can never resurrect stale state.
+ mention_ids: list[int] = []
+ ctx = getattr(runtime, "context", None) if runtime is not None else None
+ ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
+ if ctx_mentions is not None:
+ # Runtime path is authoritative — even an empty list means
+ # "this turn has no mentions", NOT "look at the closure".
+ mention_ids = list(ctx_mentions)
+ if self.mentioned_document_ids:
+ self.mentioned_document_ids = []
+ elif self.mentioned_document_ids:
+ mention_ids = list(self.mentioned_document_ids)
+ self.mentioned_document_ids = []
+
mentioned_results: list[dict[str, Any]] = []
- if self.mentioned_document_ids:
+ if mention_ids:
mentioned_results = await fetch_mentioned_documents(
- document_ids=self.mentioned_document_ids,
+ document_ids=mention_ids,
search_space_id=self.search_space_id,
)
- self.mentioned_document_ids = []
if is_recency:
doc_types = _resolve_search_types(
diff --git a/surfsense_backend/app/agents/new_chat/prompt_caching.py b/surfsense_backend/app/agents/new_chat/prompt_caching.py
index 86bc57725..9fe47cdac 100644
--- a/surfsense_backend/app/agents/new_chat/prompt_caching.py
+++ b/surfsense_backend/app/agents/new_chat/prompt_caching.py
@@ -1,4 +1,4 @@
-"""LiteLLM-native prompt caching configuration for SurfSense agents.
+r"""LiteLLM-native prompt caching configuration for SurfSense agents.
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
@@ -17,8 +17,20 @@ Coverage:
We inject **two** breakpoints per request:
-- ``role: system`` — pins the SurfSense system prompt (provider variant,
- citation rules, tool catalog, KB tree, skills metadata) into the cache.
+- ``index: 0`` — pins the SurfSense system prompt at the head of the
+ request (provider variant, citation rules, tool catalog, KB tree,
+ skills metadata). The langchain agent factory always prepends
+ ``request.system_message`` at index 0 (see ``factory.py``
+ ``_execute_model_async``), so this targets exactly the main system
+ prompt regardless of how many other ``SystemMessage``\ s the
+ ``before_agent`` injectors (priority, tree, memory, file-intent,
+ anonymous-doc) have inserted into ``state["messages"]``. Using
+ ``role: system`` here would apply ``cache_control`` to **every**
+ system-role message and trip Anthropic's hard cap of 4 cache
+ breakpoints per request once the conversation accumulates enough
+ injected system messages — which surfaces as the upstream 400
+ ``A maximum of 4 blocks with cache_control may be provided. Found N``
+ via OpenRouter→Anthropic.
- ``index: -1`` — pins the latest message so multi-turn savings compound:
Anthropic-family providers use longest-matching-prefix lookup, so turn
N+1 still reads turn N's cache up to the shared prefix.
@@ -51,11 +63,21 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-# Two-breakpoint policy: system + latest message. See module docstring for
-# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we
-# use 2 here, leaving headroom for Phase-2 tool caching.
+# Two-breakpoint policy: head-of-request + latest message. See module
+# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
+# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
+#
+# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
+# ``before_agent`` middlewares (priority, tree, memory, file-intent,
+# anonymous-doc) insert ``SystemMessage`` instances into
+# ``state["messages"]`` that accumulate across turns. With
+# ``role: system`` the LiteLLM hook would tag *every* one of them with
+# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
+# always targets the langchain-prepended ``request.system_message``
+# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
+# block), giving us exactly one stable cache breakpoint.
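+#
+# Illustrative shape only (exact payloads depend on the provider transform):
+# with these two points the system prompt reaches Anthropic as one block,
+#   {"type": "text", "text": "<system prompt>", "cache_control": {"type": "ephemeral"}},
+# and the latest message's content carries the second breakpoint.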
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
- {"location": "message", "role": "system"},
+ {"location": "message", "index": 0},
{"location": "message", "index": -1},
)
diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py
index 095413bdb..c56db1528 100644
--- a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py
@@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
+from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
@@ -18,6 +19,23 @@ def create_create_confluence_page_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """
+ Factory function to create the create_confluence_page tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured create_confluence_page tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_confluence_page(
title: str,
@@ -42,160 +60,163 @@ def create_create_confluence_page_tool(
"""
logger.info(f"create_confluence_page called: title='{title}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
- metadata_service = ConfluenceToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
+ async with async_session_maker() as db_session:
+ metadata_service = ConfluenceToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
+ )
- if "error" in context:
- return {"status": "error", "message": context["error"]}
+ if "error" in context:
+ return {"status": "error", "message": context["error"]}
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- return {
- "status": "auth_error",
- "message": "All connected Confluence accounts need re-authentication.",
- "connector_type": "confluence",
- }
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ return {
+ "status": "auth_error",
+ "message": "All connected Confluence accounts need re-authentication.",
+ "connector_type": "confluence",
+ }
- result = request_approval(
- action_type="confluence_page_creation",
- tool_name="create_confluence_page",
- params={
- "title": title,
- "content": content,
- "space_id": space_id,
- "connector_id": connector_id,
- },
- context=context,
- )
+ result = request_approval(
+ action_type="confluence_page_creation",
+ tool_name="create_confluence_page",
+ params={
+ "title": title,
+ "content": content,
+ "space_id": space_id,
+ "connector_id": connector_id,
+ },
+ context=context,
+ )
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
- final_title = result.params.get("title", title)
- final_content = result.params.get("content", content) or ""
- final_space_id = result.params.get("space_id", space_id)
- final_connector_id = result.params.get("connector_id", connector_id)
+ final_title = result.params.get("title", title)
+ final_content = result.params.get("content", content) or ""
+ final_space_id = result.params.get("space_id", space_id)
+ final_connector_id = result.params.get("connector_id", connector_id)
- if not final_title or not final_title.strip():
- return {"status": "error", "message": "Page title cannot be empty."}
- if not final_space_id:
- return {"status": "error", "message": "A space must be selected."}
+ if not final_title or not final_title.strip():
+ return {"status": "error", "message": "Page title cannot be empty."}
+ if not final_space_id:
+ return {"status": "error", "message": "A space must be selected."}
- from sqlalchemy.future import select
+ from sqlalchemy.future import select
- from app.db import SearchSourceConnector, SearchSourceConnectorType
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
- actual_connector_id = final_connector_id
- if actual_connector_id is None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+ actual_connector_id = final_connector_id
+ if actual_connector_id is None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+ )
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Confluence connector found.",
- }
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == actual_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
- )
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Confluence connector is invalid.",
- }
-
- try:
- client = ConfluenceHistoryConnector(
- session=db_session, connector_id=actual_connector_id
- )
- api_result = await client.create_page(
- space_id=final_space_id,
- title=final_title,
- body=final_content,
- )
- await client.close()
- except Exception as api_err:
- if (
- "http 403" in str(api_err).lower()
- or "status code 403" in str(api_err).lower()
- ):
- try:
- _conn = connector
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- pass
- return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
- }
- raise
-
- page_id = str(api_result.get("id", ""))
- page_links = (
- api_result.get("_links", {}) if isinstance(api_result, dict) else {}
- )
- page_url = ""
- if page_links.get("base") and page_links.get("webui"):
- page_url = f"{page_links['base']}{page_links['webui']}"
-
- kb_message_suffix = ""
- try:
- from app.services.confluence import ConfluenceKBSyncService
-
- kb_service = ConfluenceKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- page_id=page_id,
- page_title=final_title,
- space_id=final_space_id,
- body_content=final_content,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Confluence connector found.",
+ }
+ actual_connector_id = connector.id
else:
- kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == actual_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Confluence connector is invalid.",
+ }
- return {
- "status": "success",
- "page_id": page_id,
- "page_url": page_url,
- "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
- }
+ try:
+ client = ConfluenceHistoryConnector(
+ session=db_session, connector_id=actual_connector_id
+ )
+ api_result = await client.create_page(
+ space_id=final_space_id,
+ title=final_title,
+ body=final_content,
+ )
+ await client.close()
+ except Exception as api_err:
+ if (
+ "http 403" in str(api_err).lower()
+ or "status code 403" in str(api_err).lower()
+ ):
+ try:
+ _conn = connector
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ page_id = str(api_result.get("id", ""))
+ page_links = (
+ api_result.get("_links", {}) if isinstance(api_result, dict) else {}
+ )
+ page_url = ""
+ if page_links.get("base") and page_links.get("webui"):
+ page_url = f"{page_links['base']}{page_links['webui']}"
+
+ kb_message_suffix = ""
+ try:
+ from app.services.confluence import ConfluenceKBSyncService
+
+ kb_service = ConfluenceKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ page_id=page_id,
+ page_title=final_title,
+ space_id=final_space_id,
+ body_content=final_content,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "page_id": page_id,
+ "page_url": page_url,
+ "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py
index 7c03c2760..d4cd5032f 100644
--- a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py
@@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
+from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
@@ -18,6 +19,23 @@ def create_delete_confluence_page_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """
+ Factory function to create the delete_confluence_page tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured delete_confluence_page tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def delete_confluence_page(
page_title_or_id: str,
@@ -43,137 +61,143 @@ def create_delete_confluence_page_tool(
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
- metadata_service = ConfluenceToolMetadataService(db_session)
- context = await metadata_service.get_deletion_context(
- search_space_id, user_id, page_title_or_id
- )
-
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
- return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "confluence",
- }
- if "not found" in error_msg.lower():
- return {"status": "not_found", "message": error_msg}
- return {"status": "error", "message": error_msg}
-
- page_data = context["page"]
- page_id = page_data["page_id"]
- page_title = page_data.get("page_title", "")
- document_id = page_data["document_id"]
- connector_id_from_context = context.get("account", {}).get("id")
-
- result = request_approval(
- action_type="confluence_page_deletion",
- tool_name="delete_confluence_page",
- params={
- "page_id": page_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_page_id = result.params.get("page_id", page_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this page.",
- }
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+ async with async_session_maker() as db_session:
+ metadata_service = ConfluenceToolMetadataService(db_session)
+ context = await metadata_service.get_deletion_context(
+ search_space_id, user_id, page_title_or_id
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Confluence connector is invalid.",
- }
- try:
- client = ConfluenceHistoryConnector(
- session=db_session, connector_id=final_connector_id
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "confluence",
+ }
+ if "not found" in error_msg.lower():
+ return {"status": "not_found", "message": error_msg}
+ return {"status": "error", "message": error_msg}
+
+ page_data = context["page"]
+ page_id = page_data["page_id"]
+ page_title = page_data.get("page_title", "")
+ document_id = page_data["document_id"]
+ connector_id_from_context = context.get("account", {}).get("id")
+
+ result = request_approval(
+ action_type="confluence_page_deletion",
+ tool_name="delete_confluence_page",
+ params={
+ "page_id": page_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- await client.delete_page(final_page_id)
- await client.close()
- except Exception as api_err:
- if (
- "http 403" in str(api_err).lower()
- or "status code 403" in str(api_err).lower()
- ):
- try:
- connector.config = {**connector.config, "auth_expired": True}
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- pass
+
+ if result.rejected:
return {
- "status": "insufficient_permissions",
- "connector_id": final_connector_id,
- "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
}
- raise
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- from app.db import Document
+ final_page_id = result.params.get("page_id", page_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this page.",
+ }
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Confluence connector is invalid.",
+ }
- message = f"Confluence page '{page_title}' deleted successfully."
- if deleted_from_kb:
- message += " Also removed from the knowledge base."
+ try:
+ client = ConfluenceHistoryConnector(
+ session=db_session, connector_id=final_connector_id
+ )
+ await client.delete_page(final_page_id)
+ await client.close()
+ except Exception as api_err:
+ if (
+ "http 403" in str(api_err).lower()
+ or "status code 403" in str(api_err).lower()
+ ):
+ try:
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": final_connector_id,
+ "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
- return {
- "status": "success",
- "page_id": final_page_id,
- "deleted_from_kb": deleted_from_kb,
- "message": message,
- }
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+
+ message = f"Confluence page '{page_title}' deleted successfully."
+ if deleted_from_kb:
+ message += " Also removed from the knowledge base."
+
+ return {
+ "status": "success",
+ "page_id": final_page_id,
+ "deleted_from_kb": deleted_from_kb,
+ "message": message,
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
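For reference, a minimal, self-contained sketch of the session-per-call factory pattern these hunks apply. Everything in it is illustrative: the in-memory SQLite URL, the inline async_session_maker, and the create_example_tool name are assumptions standing in for the project's actual app.db wiring, and running it assumes the aiosqlite driver is installed.

    import asyncio
    from typing import Any

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import (
        AsyncSession,
        async_sessionmaker,
        create_async_engine,
    )

    # Hypothetical stand-in for app.db.async_session_maker: one factory bound to
    # the engine at import time, callable from any request or cached closure.
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async_session_maker = async_sessionmaker(engine, expire_on_commit=False)


    def create_example_tool(db_session: AsyncSession | None = None):
        # The per-request session injected by the registry is ignored: the
        # closure can outlive the request that built it once it is cached.
        del db_session

        async def example_tool() -> dict[str, Any]:
            # Every invocation opens its own short-lived session and closes it
            # on exit, so a cache hit never reuses another request's session.
            async with async_session_maker() as session:
                value = (await session.execute(text("SELECT 1"))).scalar_one()
            return {"status": "success", "value": value}

        return example_tool


    async def main() -> None:
        tool = create_example_tool()
        print(await tool())  # request that built the closure
        print(await tool())  # later request reusing the cached closure
        await engine.dispose()


    asyncio.run(main())

The same shape appears in every tool touched by this patch: the factory drops the injected session, and the tool body wraps its whole unit of work in a fresh async_session_maker() context.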
diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py
index 791d0d8c5..51c205e00 100644
--- a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py
@@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
+from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
@@ -18,6 +19,23 @@ def create_update_confluence_page_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """
+ Factory function to create the update_confluence_page tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured update_confluence_page tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def update_confluence_page(
page_title_or_id: str,
@@ -45,164 +63,168 @@ def create_update_confluence_page_tool(
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
- metadata_service = ConfluenceToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, page_title_or_id
- )
+ async with async_session_maker() as db_session:
+ metadata_service = ConfluenceToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, page_title_or_id
+ )
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "confluence",
+ }
+ if "not found" in error_msg.lower():
+ return {"status": "not_found", "message": error_msg}
+ return {"status": "error", "message": error_msg}
+
+ page_data = context["page"]
+ page_id = page_data["page_id"]
+ current_title = page_data["page_title"]
+ current_body = page_data.get("body", "")
+ current_version = page_data.get("version", 1)
+ document_id = page_data.get("document_id")
+ connector_id_from_context = context.get("account", {}).get("id")
+
+ result = request_approval(
+ action_type="confluence_page_update",
+ tool_name="update_confluence_page",
+ params={
+ "page_id": page_id,
+ "document_id": document_id,
+ "new_title": new_title,
+ "new_content": new_content,
+ "version": current_version,
+ "connector_id": connector_id_from_context,
+ },
+ context=context,
+ )
+
+ if result.rejected:
return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "confluence",
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
}
- if "not found" in error_msg.lower():
- return {"status": "not_found", "message": error_msg}
- return {"status": "error", "message": error_msg}
- page_data = context["page"]
- page_id = page_data["page_id"]
- current_title = page_data["page_title"]
- current_body = page_data.get("body", "")
- current_version = page_data.get("version", 1)
- document_id = page_data.get("document_id")
- connector_id_from_context = context.get("account", {}).get("id")
-
- result = request_approval(
- action_type="confluence_page_update",
- tool_name="update_confluence_page",
- params={
- "page_id": page_id,
- "document_id": document_id,
- "new_title": new_title,
- "new_content": new_content,
- "version": current_version,
- "connector_id": connector_id_from_context,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_page_id = result.params.get("page_id", page_id)
- final_title = result.params.get("new_title", new_title) or current_title
- final_content = result.params.get("new_content", new_content)
- if final_content is None:
- final_content = current_body
- final_version = result.params.get("version", current_version)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_document_id = result.params.get("document_id", document_id)
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this page.",
- }
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+ final_page_id = result.params.get("page_id", page_id)
+ final_title = result.params.get("new_title", new_title) or current_title
+ final_content = result.params.get("new_content", new_content)
+ if final_content is None:
+ final_content = current_body
+ final_version = result.params.get("version", current_version)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Confluence connector is invalid.",
- }
+ final_document_id = result.params.get("document_id", document_id)
- try:
- client = ConfluenceHistoryConnector(
- session=db_session, connector_id=final_connector_id
- )
- api_result = await client.update_page(
- page_id=final_page_id,
- title=final_title,
- body=final_content,
- version_number=final_version + 1,
- )
- await client.close()
- except Exception as api_err:
- if (
- "http 403" in str(api_err).lower()
- or "status code 403" in str(api_err).lower()
- ):
- try:
- connector.config = {**connector.config, "auth_expired": True}
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- pass
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if not final_connector_id:
return {
- "status": "insufficient_permissions",
- "connector_id": final_connector_id,
- "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "error",
+ "message": "No connector found for this page.",
}
- raise
- page_links = (
- api_result.get("_links", {}) if isinstance(api_result, dict) else {}
- )
- page_url = ""
- if page_links.get("base") and page_links.get("webui"):
- page_url = f"{page_links['base']}{page_links['webui']}"
-
- kb_message_suffix = ""
- if final_document_id:
- try:
- from app.services.confluence import ConfluenceKBSyncService
-
- kb_service = ConfluenceKBSyncService(db_session)
- kb_result = await kb_service.sync_after_update(
- document_id=final_document_id,
- page_id=final_page_id,
- user_id=user_id,
- search_space_id=search_space_id,
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
- if kb_result["status"] == "success":
- kb_message_suffix = (
- " Your knowledge base has also been updated."
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Confluence connector is invalid.",
+ }
+
+ try:
+ client = ConfluenceHistoryConnector(
+ session=db_session, connector_id=final_connector_id
+ )
+ api_result = await client.update_page(
+ page_id=final_page_id,
+ title=final_title,
+ body=final_content,
+ version_number=final_version + 1,
+ )
+ await client.close()
+ except Exception as api_err:
+ if (
+ "http 403" in str(api_err).lower()
+ or "status code 403" in str(api_err).lower()
+ ):
+ try:
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": final_connector_id,
+ "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ page_links = (
+ api_result.get("_links", {}) if isinstance(api_result, dict) else {}
+ )
+ page_url = ""
+ if page_links.get("base") and page_links.get("webui"):
+ page_url = f"{page_links['base']}{page_links['webui']}"
+
+ kb_message_suffix = ""
+ if final_document_id:
+ try:
+ from app.services.confluence import ConfluenceKBSyncService
+
+ kb_service = ConfluenceKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_update(
+ document_id=final_document_id,
+ page_id=final_page_id,
+ user_id=user_id,
+ search_space_id=search_space_id,
)
- else:
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = (
+ " The knowledge base will be updated in the next sync."
+ )
+ except Exception as kb_err:
+ logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
- except Exception as kb_err:
- logger.warning(f"KB sync after update failed: {kb_err}")
- kb_message_suffix = (
- " The knowledge base will be updated in the next sync."
- )
- return {
- "status": "success",
- "page_id": final_page_id,
- "page_url": page_url,
- "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
- }
+ return {
+ "status": "success",
+ "page_id": final_page_id,
+ "page_url": page_url,
+ "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py
index 5675a42e6..6420a90e6 100644
--- a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py
+++ b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py
@@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
from app.services.mcp_oauth.registry import MCP_SERVICES
logger = logging.getLogger(__name__)
@@ -53,6 +53,23 @@ def create_get_connected_accounts_tool(
search_space_id: int,
user_id: str,
) -> StructuredTool:
+ """Factory function to create the get_connected_accounts tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to scope account discovery to.
+ user_id: User ID to scope account discovery to.
+
+ Returns:
+ Configured StructuredTool for connected-accounts discovery.
+ """
+ del db_session # per-call session — see docstring
async def _run(service: str) -> list[dict[str, Any]]:
svc_cfg = MCP_SERVICES.get(service)
@@ -68,40 +85,41 @@ def create_get_connected_accounts_tool(
except ValueError:
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type == connector_type,
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type == connector_type,
+ )
)
- )
- connectors = result.scalars().all()
+ connectors = result.scalars().all()
- if not connectors:
- return [
- {
- "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."
+ if not connectors:
+ return [
+ {
+ "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."
+ }
+ ]
+
+ is_multi = len(connectors) > 1
+
+ accounts: list[dict[str, Any]] = []
+ for conn in connectors:
+ cfg = conn.config or {}
+ entry: dict[str, Any] = {
+ "connector_id": conn.id,
+ "display_name": _extract_display_name(conn),
+ "service": service,
}
- ]
+ if is_multi:
+ entry["tool_prefix"] = f"{service}_{conn.id}"
+ for key in svc_cfg.account_metadata_keys:
+ if key in cfg:
+ entry[key] = cfg[key]
+ accounts.append(entry)
- is_multi = len(connectors) > 1
-
- accounts: list[dict[str, Any]] = []
- for conn in connectors:
- cfg = conn.config or {}
- entry: dict[str, Any] = {
- "connector_id": conn.id,
- "display_name": _extract_display_name(conn),
- "service": service,
- }
- if is_multi:
- entry["tool_prefix"] = f"{service}_{conn.id}"
- for key in svc_cfg.account_metadata_keys:
- if key in cfg:
- entry[key] = cfg[key]
- accounts.append(entry)
-
- return accounts
+ return accounts
return StructuredTool(
name="get_connected_accounts",
diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py
index 3cc99ac17..01159a261 100644
--- a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py
+++ b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_list_discord_channels_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the list_discord_channels tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured list_discord_channels tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def list_discord_channels() -> dict[str, Any]:
"""List text channels in the connected Discord server.
@@ -22,59 +41,60 @@ def create_list_discord_channels_tool(
Returns:
Dictionary with status and a list of channels (id, name).
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Discord tool not properly configured.",
}
try:
- connector = await get_discord_connector(
- db_session, search_space_id, user_id
- )
- if not connector:
- return {"status": "error", "message": "No Discord connector found."}
-
- guild_id = get_guild_id(connector)
- if not guild_id:
- return {
- "status": "error",
- "message": "No guild ID in Discord connector config.",
- }
-
- token = get_bot_token(connector)
-
- async with httpx.AsyncClient() as client:
- resp = await client.get(
- f"{DISCORD_API}/guilds/{guild_id}/channels",
- headers={"Authorization": f"Bot {token}"},
- timeout=15.0,
+ async with async_session_maker() as db_session:
+ connector = await get_discord_connector(
+ db_session, search_space_id, user_id
)
+ if not connector:
+ return {"status": "error", "message": "No Discord connector found."}
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Discord bot token is invalid.",
- "connector_type": "discord",
- }
- if resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Discord API error: {resp.status_code}",
- }
+ guild_id = get_guild_id(connector)
+ if not guild_id:
+ return {
+ "status": "error",
+ "message": "No guild ID in Discord connector config.",
+ }
- # Type 0 = text channel
- channels = [
- {"id": ch["id"], "name": ch["name"]}
- for ch in resp.json()
- if ch.get("type") == 0
- ]
- return {
- "status": "success",
- "guild_id": guild_id,
- "channels": channels,
- "total": len(channels),
- }
+ token = get_bot_token(connector)
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.get(
+ f"{DISCORD_API}/guilds/{guild_id}/channels",
+ headers={"Authorization": f"Bot {token}"},
+ timeout=15.0,
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Discord bot token is invalid.",
+ "connector_type": "discord",
+ }
+ if resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Discord API error: {resp.status_code}",
+ }
+
+ # Type 0 = text channel
+ channels = [
+ {"id": ch["id"], "name": ch["name"]}
+ for ch in resp.json()
+ if ch.get("type") == 0
+ ]
+ return {
+ "status": "success",
+ "guild_id": guild_id,
+ "channels": channels,
+ "total": len(channels),
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py
index d8bf989a1..88d6cdd49 100644
--- a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py
+++ b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_read_discord_messages_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the read_discord_messages tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured read_discord_messages tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def read_discord_messages(
channel_id: str,
@@ -30,7 +49,7 @@ def create_read_discord_messages_tool(
Dictionary with status and a list of messages including
id, author, content, timestamp.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Discord tool not properly configured.",
@@ -39,55 +58,56 @@ def create_read_discord_messages_tool(
limit = min(limit, 50)
try:
- connector = await get_discord_connector(
- db_session, search_space_id, user_id
- )
- if not connector:
- return {"status": "error", "message": "No Discord connector found."}
-
- token = get_bot_token(connector)
-
- async with httpx.AsyncClient() as client:
- resp = await client.get(
- f"{DISCORD_API}/channels/{channel_id}/messages",
- headers={"Authorization": f"Bot {token}"},
- params={"limit": limit},
- timeout=15.0,
+ async with async_session_maker() as db_session:
+ connector = await get_discord_connector(
+ db_session, search_space_id, user_id
)
+ if not connector:
+ return {"status": "error", "message": "No Discord connector found."}
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Discord bot token is invalid.",
- "connector_type": "discord",
- }
- if resp.status_code == 403:
- return {
- "status": "error",
- "message": "Bot lacks permission to read this channel.",
- }
- if resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Discord API error: {resp.status_code}",
- }
+ token = get_bot_token(connector)
- messages = [
- {
- "id": m["id"],
- "author": m.get("author", {}).get("username", "Unknown"),
- "content": m.get("content", ""),
- "timestamp": m.get("timestamp", ""),
- }
- for m in resp.json()
- ]
+ async with httpx.AsyncClient() as client:
+ resp = await client.get(
+ f"{DISCORD_API}/channels/{channel_id}/messages",
+ headers={"Authorization": f"Bot {token}"},
+ params={"limit": limit},
+ timeout=15.0,
+ )
- return {
- "status": "success",
- "channel_id": channel_id,
- "messages": messages,
- "total": len(messages),
- }
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Discord bot token is invalid.",
+ "connector_type": "discord",
+ }
+ if resp.status_code == 403:
+ return {
+ "status": "error",
+ "message": "Bot lacks permission to read this channel.",
+ }
+ if resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Discord API error: {resp.status_code}",
+ }
+
+ messages = [
+ {
+ "id": m["id"],
+ "author": m.get("author", {}).get("username", "Unknown"),
+ "content": m.get("content", ""),
+ "timestamp": m.get("timestamp", ""),
+ }
+ for m in resp.json()
+ ]
+
+ return {
+ "status": "success",
+ "channel_id": channel_id,
+ "messages": messages,
+ "total": len(messages),
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py
index 236cd017a..5fe6fde35 100644
--- a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py
+++ b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py
@@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
@@ -17,6 +18,23 @@ def create_send_discord_message_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the send_discord_message tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured send_discord_message tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def send_discord_message(
channel_id: str,
@@ -34,7 +52,7 @@ def create_send_discord_message_tool(
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Discord tool not properly configured.",
@@ -47,64 +65,65 @@ def create_send_discord_message_tool(
}
try:
- connector = await get_discord_connector(
- db_session, search_space_id, user_id
- )
- if not connector:
- return {"status": "error", "message": "No Discord connector found."}
+ async with async_session_maker() as db_session:
+ connector = await get_discord_connector(
+ db_session, search_space_id, user_id
+ )
+ if not connector:
+ return {"status": "error", "message": "No Discord connector found."}
- result = request_approval(
- action_type="discord_send_message",
- tool_name="send_discord_message",
- params={"channel_id": channel_id, "content": content},
- context={"connector_id": connector.id},
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Message was not sent.",
- }
-
- final_content = result.params.get("content", content)
- final_channel = result.params.get("channel_id", channel_id)
-
- token = get_bot_token(connector)
-
- async with httpx.AsyncClient() as client:
- resp = await client.post(
- f"{DISCORD_API}/channels/{final_channel}/messages",
- headers={
- "Authorization": f"Bot {token}",
- "Content-Type": "application/json",
- },
- json={"content": final_content},
- timeout=15.0,
+ result = request_approval(
+ action_type="discord_send_message",
+ tool_name="send_discord_message",
+ params={"channel_id": channel_id, "content": content},
+ context={"connector_id": connector.id},
)
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Discord bot token is invalid.",
- "connector_type": "discord",
- }
- if resp.status_code == 403:
- return {
- "status": "error",
- "message": "Bot lacks permission to send messages in this channel.",
- }
- if resp.status_code not in (200, 201):
- return {
- "status": "error",
- "message": f"Discord API error: {resp.status_code}",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Message was not sent.",
+ }
- msg_data = resp.json()
- return {
- "status": "success",
- "message_id": msg_data.get("id"),
- "message": f"Message sent to channel {final_channel}.",
- }
+ final_content = result.params.get("content", content)
+ final_channel = result.params.get("channel_id", channel_id)
+
+ token = get_bot_token(connector)
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.post(
+ f"{DISCORD_API}/channels/{final_channel}/messages",
+ headers={
+ "Authorization": f"Bot {token}",
+ "Content-Type": "application/json",
+ },
+ json={"content": final_content},
+ timeout=15.0,
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Discord bot token is invalid.",
+ "connector_type": "discord",
+ }
+ if resp.status_code == 403:
+ return {
+ "status": "error",
+ "message": "Bot lacks permission to send messages in this channel.",
+ }
+ if resp.status_code not in (200, 201):
+ return {
+ "status": "error",
+ "message": f"Discord API error: {resp.status_code}",
+ }
+
+ msg_data = resp.json()
+ return {
+ "status": "success",
+ "message_id": msg_data.get("id"),
+ "message": f"Message sent to channel {final_channel}.",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py
index 22d8a8a27..7aae034cc 100644
--- a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py
@@ -10,7 +10,7 @@ from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@@ -59,6 +59,23 @@ def create_create_dropbox_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_dropbox_file tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured create_dropbox_file tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_dropbox_file(
name: str,
@@ -82,184 +99,191 @@ def create_create_dropbox_file_tool(
f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Dropbox tool not properly configured.",
}
try:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.DROPBOX_CONNECTOR,
- )
- )
- connectors = result.scalars().all()
-
- if not connectors:
- return {
- "status": "error",
- "message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
- }
-
- accounts = []
- for c in connectors:
- cfg = c.config or {}
- accounts.append(
- {
- "id": c.id,
- "name": c.name,
- "user_email": cfg.get("user_email"),
- "auth_expired": cfg.get("auth_expired", False),
- }
- )
-
- if all(a.get("auth_expired") for a in accounts):
- return {
- "status": "auth_error",
- "message": "All connected Dropbox accounts need re-authentication.",
- "connector_type": "dropbox",
- }
-
- parent_folders: dict[int, list[dict[str, str]]] = {}
- for acc in accounts:
- cid = acc["id"]
- if acc.get("auth_expired"):
- parent_folders[cid] = []
- continue
- try:
- client = DropboxClient(session=db_session, connector_id=cid)
- items, err = await client.list_folder("")
- if err:
- logger.warning(
- "Failed to list folders for connector %s: %s", cid, err
- )
- parent_folders[cid] = []
- else:
- parent_folders[cid] = [
- {
- "folder_path": item.get("path_lower", ""),
- "name": item["name"],
- }
- for item in items
- if item.get(".tag") == "folder" and item.get("name")
- ]
- except Exception:
- logger.warning(
- "Error fetching folders for connector %s", cid, exc_info=True
- )
- parent_folders[cid] = []
-
- context: dict[str, Any] = {
- "accounts": accounts,
- "parent_folders": parent_folders,
- "supported_types": _SUPPORTED_TYPES,
- }
-
- result = request_approval(
- action_type="dropbox_file_creation",
- tool_name="create_dropbox_file",
- params={
- "name": name,
- "file_type": file_type,
- "content": content,
- "connector_id": None,
- "parent_folder_path": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_name = result.params.get("name", name)
- final_file_type = result.params.get("file_type", file_type)
- final_content = result.params.get("content", content)
- final_connector_id = result.params.get("connector_id")
- final_parent_folder_path = result.params.get("parent_folder_path")
-
- if not final_name or not final_name.strip():
- return {"status": "error", "message": "File name cannot be empty."}
-
- final_name = _ensure_extension(final_name, final_file_type)
-
- if final_connector_id is not None:
+ async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
)
)
- connector = result.scalars().first()
- else:
- connector = connectors[0]
+ connectors = result.scalars().all()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Dropbox connector is invalid.",
+ if not connectors:
+ return {
+ "status": "error",
+ "message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
+ }
+
+ accounts = []
+ for c in connectors:
+ cfg = c.config or {}
+ accounts.append(
+ {
+ "id": c.id,
+ "name": c.name,
+ "user_email": cfg.get("user_email"),
+ "auth_expired": cfg.get("auth_expired", False),
+ }
+ )
+
+ if all(a.get("auth_expired") for a in accounts):
+ return {
+ "status": "auth_error",
+ "message": "All connected Dropbox accounts need re-authentication.",
+ "connector_type": "dropbox",
+ }
+
+ parent_folders: dict[int, list[dict[str, str]]] = {}
+ for acc in accounts:
+ cid = acc["id"]
+ if acc.get("auth_expired"):
+ parent_folders[cid] = []
+ continue
+ try:
+ client = DropboxClient(session=db_session, connector_id=cid)
+ items, err = await client.list_folder("")
+ if err:
+ logger.warning(
+ "Failed to list folders for connector %s: %s", cid, err
+ )
+ parent_folders[cid] = []
+ else:
+ parent_folders[cid] = [
+ {
+ "folder_path": item.get("path_lower", ""),
+ "name": item["name"],
+ }
+ for item in items
+ if item.get(".tag") == "folder" and item.get("name")
+ ]
+ except Exception:
+ logger.warning(
+ "Error fetching folders for connector %s",
+ cid,
+ exc_info=True,
+ )
+ parent_folders[cid] = []
+
+ context: dict[str, Any] = {
+ "accounts": accounts,
+ "parent_folders": parent_folders,
+ "supported_types": _SUPPORTED_TYPES,
}
- client = DropboxClient(session=db_session, connector_id=connector.id)
-
- parent_path = final_parent_folder_path or ""
- file_path = (
- f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
- )
-
- if final_file_type == "paper":
- created = await client.create_paper_doc(file_path, final_content or "")
- file_id = created.get("file_id", "")
- web_url = created.get("url", "")
- else:
- docx_bytes = _markdown_to_docx(final_content or "")
- created = await client.upload_file(
- file_path, docx_bytes, mode="add", autorename=True
+ result = request_approval(
+ action_type="dropbox_file_creation",
+ tool_name="create_dropbox_file",
+ params={
+ "name": name,
+ "file_type": file_type,
+ "content": content,
+ "connector_id": None,
+ "parent_folder_path": None,
+ },
+ context=context,
)
- file_id = created.get("id", "")
- web_url = ""
- logger.info(f"Dropbox file created: id={file_id}, name={final_name}")
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
- kb_message_suffix = ""
- try:
- from app.services.dropbox import DropboxKBSyncService
+ final_name = result.params.get("name", name)
+ final_file_type = result.params.get("file_type", file_type)
+ final_content = result.params.get("content", content)
+ final_connector_id = result.params.get("connector_id")
+ final_parent_folder_path = result.params.get("parent_folder_path")
- kb_service = DropboxKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- file_id=file_id,
- file_name=final_name,
- file_path=file_path,
- web_url=web_url,
- content=final_content,
- connector_id=connector.id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
+ if not final_name or not final_name.strip():
+ return {"status": "error", "message": "File name cannot be empty."}
+
+ final_name = _ensure_extension(final_name, final_file_type)
+
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.DROPBOX_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
else:
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+ connector = connectors[0]
- return {
- "status": "success",
- "file_id": file_id,
- "name": final_name,
- "web_url": web_url,
- "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
- }
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Dropbox connector is invalid.",
+ }
+
+ client = DropboxClient(session=db_session, connector_id=connector.id)
+
+ parent_path = final_parent_folder_path or ""
+ file_path = (
+ f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
+ )
+
+ if final_file_type == "paper":
+ created = await client.create_paper_doc(
+ file_path, final_content or ""
+ )
+ file_id = created.get("file_id", "")
+ web_url = created.get("url", "")
+ else:
+ docx_bytes = _markdown_to_docx(final_content or "")
+ created = await client.upload_file(
+ file_path, docx_bytes, mode="add", autorename=True
+ )
+ file_id = created.get("id", "")
+ web_url = ""
+
+ logger.info(f"Dropbox file created: id={file_id}, name={final_name}")
+
+ kb_message_suffix = ""
+ try:
+ from app.services.dropbox import DropboxKBSyncService
+
+ kb_service = DropboxKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ file_id=file_id,
+ file_name=final_name,
+ file_path=file_path,
+ web_url=web_url,
+ content=final_content,
+ connector_id=connector.id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "file_id": file_id,
+ "name": final_name,
+ "web_url": web_url,
+ "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py
index 12559b57a..0e59e49db 100644
--- a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py
@@ -13,6 +13,7 @@ from app.db import (
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
+ async_session_maker,
)
logger = logging.getLogger(__name__)
@@ -23,6 +24,23 @@ def create_delete_dropbox_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the delete_dropbox_file tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured delete_dropbox_file tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def delete_dropbox_file(
file_name: str,
@@ -55,33 +73,14 @@ def create_delete_dropbox_file_tool(
f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Dropbox tool not properly configured.",
}
try:
- doc_result = await db_session.execute(
- select(Document)
- .join(
- SearchSourceConnector,
- Document.connector_id == SearchSourceConnector.id,
- )
- .filter(
- and_(
- Document.search_space_id == search_space_id,
- Document.document_type == DocumentType.DROPBOX_FILE,
- func.lower(Document.title) == func.lower(file_name),
- SearchSourceConnector.user_id == user_id,
- )
- )
- .order_by(Document.updated_at.desc().nullslast())
- .limit(1)
- )
- document = doc_result.scalars().first()
-
- if not document:
+ async with async_session_maker() as db_session:
doc_result = await db_session.execute(
select(Document)
.join(
@@ -92,13 +91,7 @@ def create_delete_dropbox_file_tool(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.DROPBOX_FILE,
- func.lower(
- cast(
- Document.document_metadata["dropbox_file_name"],
- String,
- )
- )
- == func.lower(file_name),
+ func.lower(Document.title) == func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
@@ -107,99 +100,63 @@ def create_delete_dropbox_file_tool(
)
document = doc_result.scalars().first()
- if not document:
- return {
- "status": "not_found",
- "message": (
- f"File '{file_name}' not found in your indexed Dropbox files. "
- "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
- "or (3) the file name is different."
- ),
- }
-
- if not document.connector_id:
- return {
- "status": "error",
- "message": "Document has no associated connector.",
- }
-
- meta = document.document_metadata or {}
- file_path = meta.get("dropbox_path")
- file_id = meta.get("dropbox_file_id")
- document_id = document.id
-
- if not file_path:
- return {
- "status": "error",
- "message": "File path is missing. Please re-index the file.",
- }
-
- conn_result = await db_session.execute(
- select(SearchSourceConnector).filter(
- and_(
- SearchSourceConnector.id == document.connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.DROPBOX_CONNECTOR,
+ if not document:
+ doc_result = await db_session.execute(
+ select(Document)
+ .join(
+ SearchSourceConnector,
+ Document.connector_id == SearchSourceConnector.id,
+ )
+ .filter(
+ and_(
+ Document.search_space_id == search_space_id,
+ Document.document_type == DocumentType.DROPBOX_FILE,
+ func.lower(
+ cast(
+ Document.document_metadata["dropbox_file_name"],
+ String,
+ )
+ )
+ == func.lower(file_name),
+ SearchSourceConnector.user_id == user_id,
+ )
+ )
+ .order_by(Document.updated_at.desc().nullslast())
+ .limit(1)
)
- )
- )
- connector = conn_result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Dropbox connector not found or access denied.",
- }
+ document = doc_result.scalars().first()
- cfg = connector.config or {}
- if cfg.get("auth_expired"):
- return {
- "status": "auth_error",
- "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "dropbox",
- }
+ if not document:
+ return {
+ "status": "not_found",
+ "message": (
+ f"File '{file_name}' not found in your indexed Dropbox files. "
+ "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
+ "or (3) the file name is different."
+ ),
+ }
- context = {
- "file": {
- "file_id": file_id,
- "file_path": file_path,
- "name": file_name,
- "document_id": document_id,
- },
- "account": {
- "id": connector.id,
- "name": connector.name,
- "user_email": cfg.get("user_email"),
- },
- }
+ if not document.connector_id:
+ return {
+ "status": "error",
+ "message": "Document has no associated connector.",
+ }
- result = request_approval(
- action_type="dropbox_file_trash",
- tool_name="delete_dropbox_file",
- params={
- "file_path": file_path,
- "connector_id": connector.id,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
+ meta = document.document_metadata or {}
+ file_path = meta.get("dropbox_path")
+ file_id = meta.get("dropbox_file_id")
+ document_id = document.id
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
+ if not file_path:
+ return {
+ "status": "error",
+ "message": "File path is missing. Please re-index the file.",
+ }
- final_file_path = result.params.get("file_path", file_path)
- final_connector_id = result.params.get("connector_id", connector.id)
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- if final_connector_id != connector.id:
- result = await db_session.execute(
+ conn_result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
- SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
@@ -207,61 +164,128 @@ def create_delete_dropbox_file_tool(
)
)
)
- validated_connector = result.scalars().first()
- if not validated_connector:
+ connector = conn_result.scalars().first()
+ if not connector:
return {
"status": "error",
- "message": "Selected Dropbox connector is invalid or has been disconnected.",
+ "message": "Dropbox connector not found or access denied.",
}
- actual_connector_id = validated_connector.id
- else:
- actual_connector_id = connector.id
- logger.info(
- f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
- )
+ cfg = connector.config or {}
+ if cfg.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "dropbox",
+ }
- client = DropboxClient(session=db_session, connector_id=actual_connector_id)
- await client.delete_file(final_file_path)
+ context = {
+ "file": {
+ "file_id": file_id,
+ "file_path": file_path,
+ "name": file_name,
+ "document_id": document_id,
+ },
+ "account": {
+ "id": connector.id,
+ "name": connector.name,
+ "user_email": cfg.get("user_email"),
+ },
+ }
- logger.info(f"Dropbox file deleted: path={final_file_path}")
-
- trash_result: dict[str, Any] = {
- "status": "success",
- "file_id": file_id,
- "message": f"Successfully deleted '{file_name}' from Dropbox.",
- }
-
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- doc = doc_result.scalars().first()
- if doc:
- await db_session.delete(doc)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- trash_result["warning"] = (
- f"File deleted, but failed to remove from knowledge base: {e!s}"
- )
-
- trash_result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- trash_result["message"] = (
- f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ result = request_approval(
+ action_type="dropbox_file_trash",
+ tool_name="delete_dropbox_file",
+ params={
+ "file_path": file_path,
+ "connector_id": connector.id,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- return trash_result
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_file_path = result.params.get("file_path", file_path)
+ final_connector_id = result.params.get("connector_id", connector.id)
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
+
+ if final_connector_id != connector.id:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ and_(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id
+ == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.DROPBOX_CONNECTOR,
+ )
+ )
+ )
+ validated_connector = result.scalars().first()
+ if not validated_connector:
+ return {
+ "status": "error",
+ "message": "Selected Dropbox connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = validated_connector.id
+ else:
+ actual_connector_id = connector.id
+
+ logger.info(
+ f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
+ )
+
+ client = DropboxClient(
+ session=db_session, connector_id=actual_connector_id
+ )
+ await client.delete_file(final_file_path)
+
+ logger.info(f"Dropbox file deleted: path={final_file_path}")
+
+ trash_result: dict[str, Any] = {
+ "status": "success",
+ "file_id": file_id,
+ "message": f"Successfully deleted '{file_name}' from Dropbox.",
+ }
+
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ doc = doc_result.scalars().first()
+ if doc:
+ await db_session.delete(doc)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ trash_result["warning"] = (
+ f"File deleted, but failed to remove from knowledge base: {e!s}"
+ )
+
+ trash_result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ trash_result["message"] = (
+ f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt
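
The reworked lookup above resolves a Dropbox file in two ordered steps: an exact case-insensitive match on Document.title first, then a fallback match on the dropbox_file_name stored in document_metadata, taking the most recently updated row. A minimal, self-contained sketch of that ordering (the in-memory docs records and the find_dropbox_document helper are illustrative stand-ins, not code from this patch):

from datetime import datetime

# Illustrative stand-ins for indexed Document rows; the field names mirror the
# columns the hunk queries (title, document_metadata["dropbox_file_name"],
# updated_at).
docs = [
    {"id": 1, "title": "Q3 Report.pdf",
     "metadata": {"dropbox_file_name": "Q3 Report.pdf"},
     "updated_at": datetime(2026, 4, 1)},
    {"id": 2, "title": "meeting notes",
     "metadata": {"dropbox_file_name": "Team Sync 2026-04.paper"},
     "updated_at": datetime(2026, 4, 20)},
]

def find_dropbox_document(file_name: str) -> dict | None:
    """Exact (case-insensitive) title match first, metadata fallback second."""
    wanted = file_name.lower()
    # Step 1: func.lower(Document.title) == func.lower(file_name)
    for doc in docs:
        if doc["title"].lower() == wanted:
            return doc
    # Step 2: case-insensitive match on dropbox_file_name, newest update wins
    candidates = [
        d for d in docs
        if (d["metadata"].get("dropbox_file_name") or "").lower() == wanted
    ]
    candidates.sort(key=lambda d: d["updated_at"] or datetime.min, reverse=True)
    return candidates[0] if candidates else None

print(find_dropbox_document("q3 report.pdf")["id"])            # 1, exact title match
print(find_dropbox_document("Team Sync 2026-04.paper")["id"])  # 2, metadata fallback

The fallback presumably covers files whose stored title was normalized differently from the Dropbox file name the user refers to.
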
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py
index 7e9ddf7d3..c88b48d2d 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_create_gmail_draft_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_gmail_draft tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured create_gmail_draft tool
+ """
+    del db_session  # unused: each call opens its own session (see docstring)
+
@tool
async def create_gmail_draft(
to: str,
@@ -57,267 +75,276 @@ def create_create_gmail_draft_tool(
"""
logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
- metadata_service = GmailToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
+ async with async_session_maker() as db_session:
+ metadata_service = GmailToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
+ )
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {"status": "error", "message": context["error"]}
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- logger.warning("All Gmail accounts have expired authentication")
- return {
- "status": "auth_error",
- "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "gmail",
- }
-
- logger.info(
- f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
- )
- result = request_approval(
- action_type="gmail_draft_creation",
- tool_name="create_gmail_draft",
- params={
- "to": to,
- "subject": subject,
- "body": body,
- "cc": cc,
- "bcc": bcc,
- "connector_id": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
- }
-
- final_to = result.params.get("to", to)
- final_subject = result.params.get("subject", subject)
- final_body = result.params.get("body", body)
- final_cc = result.params.get("cc", cc)
- final_bcc = result.params.get("bcc", bcc)
- final_connector_id = result.params.get("connector_id")
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _gmail_types = [
- SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
- ]
-
- if final_connector_id is not None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
+ if "error" in context:
+ logger.error(
+ f"Failed to fetch creation context: {context['error']}"
)
- )
- connector = result.scalars().first()
- if not connector:
+ return {"status": "error", "message": context["error"]}
+
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ logger.warning("All Gmail accounts have expired authentication")
return {
- "status": "error",
- "message": "Selected Gmail connector is invalid or has been disconnected.",
+ "status": "auth_error",
+ "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "gmail",
}
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
+
+ logger.info(
+ f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
+ )
+ result = request_approval(
+ action_type="gmail_draft_creation",
+ tool_name="create_gmail_draft",
+ params={
+ "to": to,
+ "subject": subject,
+ "body": body,
+ "cc": cc,
+ "bcc": bcc,
+ "connector_id": None,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
+ }
+
+ final_to = result.params.get("to", to)
+ final_subject = result.params.get("subject", subject)
+ final_body = result.params.get("body", body)
+ final_cc = result.params.get("cc", cc)
+ final_bcc = result.params.get("bcc", bcc)
+ final_connector_id = result.params.get("connector_id")
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _gmail_types = [
+ SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
+ ]
+
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
+ )
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
- }
- actual_connector_id = connector.id
-
- logger.info(
- f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
- )
-
- is_composio_gmail = (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- )
- if is_composio_gmail:
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
- return {
- "status": "error",
- "message": "Composio connected account ID not found for this Gmail connector.",
- }
- else:
- from google.oauth2.credentials import Credentials
-
- from app.config import config
- from app.utils.oauth_security import TokenEncryption
-
- config_data = dict(connector.config)
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and config.SECRET_KEY:
- token_encryption = TokenEncryption(config.SECRET_KEY)
- if config_data.get("token"):
- config_data["token"] = token_encryption.decrypt_token(
- config_data["token"]
- )
- if config_data.get("refresh_token"):
- config_data["refresh_token"] = token_encryption.decrypt_token(
- config_data["refresh_token"]
- )
- if config_data.get("client_secret"):
- config_data["client_secret"] = token_encryption.decrypt_token(
- config_data["client_secret"]
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Gmail connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = connector.id
+ else:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
)
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
+ }
+ actual_connector_id = connector.id
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ logger.info(
+ f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
)
- message = MIMEText(final_body)
- message["to"] = final_to
- message["subject"] = final_subject
- if final_cc:
- message["cc"] = final_cc
- if final_bcc:
- message["bcc"] = final_bcc
- raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
-
- try:
+ is_composio_gmail = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ )
if is_composio_gmail:
- from app.agents.new_chat.tools.gmail.composio_helpers import (
- execute_composio_gmail_tool,
- split_recipients,
- )
-
- created, error = await execute_composio_gmail_tool(
- connector,
- user_id,
- "GMAIL_CREATE_EMAIL_DRAFT",
- {
- "user_id": "me",
- "recipient_email": final_to,
- "subject": final_subject,
- "body": final_body,
- "cc": split_recipients(final_cc),
- "bcc": split_recipients(final_bcc),
- "is_html": False,
- },
- )
- if error:
- raise RuntimeError(error)
- if not isinstance(created, dict):
- created = {}
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Gmail connector.",
+ }
else:
- from googleapiclient.discovery import build
+ from google.oauth2.credentials import Credentials
- gmail_service = build("gmail", "v1", credentials=creds)
- created = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .drafts()
- .create(userId="me", body={"message": {"raw": raw}})
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
- )
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
+ config_data = dict(connector.config)
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
)
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = (
+ token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = (
+ token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
+ message = MIMEText(final_body)
+ message["to"] = final_to
+ message["subject"] = final_subject
+ if final_cc:
+ message["cc"] = final_cc
+ if final_bcc:
+ message["bcc"] = final_bcc
+ raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
+
+ try:
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ split_recipients,
)
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
+
+ created, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_CREATE_EMAIL_DRAFT",
+ {
+ "user_id": "me",
+ "recipient_email": final_to,
+ "subject": final_subject,
+ "body": final_body,
+ "cc": split_recipients(final_cc),
+ "bcc": split_recipients(final_bcc),
+ "is_html": False,
+ },
+ )
+ if error:
+ raise RuntimeError(error)
+ if not isinstance(created, dict):
+ created = {}
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ created = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .drafts()
+ .create(userId="me", body={"message": {"raw": raw}})
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
+ f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
- return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
- }
- raise
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
- logger.info(f"Gmail draft created: id={created.get('id')}")
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
- kb_message_suffix = ""
- try:
- from app.services.gmail import GmailKBSyncService
+ logger.info(f"Gmail draft created: id={created.get('id')}")
- kb_service = GmailKBSyncService(db_session)
- draft_message = created.get("message", {})
- kb_result = await kb_service.sync_after_create(
- message_id=draft_message.get("id", ""),
- thread_id=draft_message.get("threadId", ""),
- subject=final_subject,
- sender="me",
- date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
- body_text=final_body,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
- draft_id=created.get("id"),
- )
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
- else:
+ kb_message_suffix = ""
+ try:
+ from app.services.gmail import GmailKBSyncService
+
+ kb_service = GmailKBSyncService(db_session)
+ draft_message = created.get("message", {})
+ kb_result = await kb_service.sync_after_create(
+ message_id=draft_message.get("id", ""),
+ thread_id=draft_message.get("threadId", ""),
+ subject=final_subject,
+ sender="me",
+ date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+ body_text=final_body,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ draft_id=created.get("id"),
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
- return {
- "status": "success",
- "draft_id": created.get("id"),
- "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
- }
+ return {
+ "status": "success",
+ "draft_id": created.get("id"),
+ "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
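
The docstring added to this factory (and to the other Gmail factories below) describes one pattern: ignore the injected per-request session and open a fresh AsyncSession inside the tool body, so the closure cached with the compiled agent never holds a stale or closed session. A minimal sketch of that shape, assuming SQLAlchemy 2.x with the aiosqlite driver; the engine URL and the create_example_tool/example_tool names are illustrative, not part of this patch:

import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

# Stand-in engine and session factory; the real app exposes async_session_maker
# from app.db.
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

def create_example_tool(db_session: AsyncSession | None = None):
    """Same shape as the factories in this patch: the injected session is unused."""
    del db_session  # each call below opens its own short-lived session instead

    async def example_tool() -> dict:
        # A fresh session per invocation means the closure captures nothing
        # request-scoped, so it can be cached and reused across HTTP requests.
        async with async_session_maker() as session:
            value = (await session.execute(text("SELECT 1"))).scalar_one()
        return {"status": "success", "value": value}

    return example_tool

print(asyncio.run(create_example_tool()()))  # {'status': 'success', 'value': 1}

Sharing the long-lived sessionmaker while scoping each session to a single call is the usual way to make cached closures like these request-safe.
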
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
index 1964181e4..464713591 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
@@ -5,7 +5,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@@ -20,6 +20,23 @@ def create_read_gmail_email_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the read_gmail_email tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured read_gmail_email tool
+ """
+    del db_session  # unused: each call opens its own session (see docstring)
+
@tool
async def read_gmail_email(message_id: str) -> dict[str, Any]:
"""Read the full content of a specific Gmail email by its message ID.
@@ -32,108 +49,115 @@ def create_read_gmail_email_tool(
Returns:
Dictionary with status and the full email content formatted as markdown.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."}
try:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
+ )
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
- }
-
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
+ connector = result.scalars().first()
+ if not connector:
return {
"status": "error",
- "message": "Composio connected account ID not found.",
+ "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
+ }
+
+ if (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ ):
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found.",
+ }
+
+ from app.agents.new_chat.tools.gmail.search_emails import (
+ _format_gmail_summary,
+ )
+ from app.services.composio_service import ComposioService
+
+ service = ComposioService()
+ detail, error = await service.get_gmail_message_detail(
+ connected_account_id=cca_id,
+ entity_id=f"surfsense_{user_id}",
+ message_id=message_id,
+ )
+ if error:
+ return {"status": "error", "message": error}
+ if not detail:
+ return {
+ "status": "not_found",
+ "message": f"Email with ID '{message_id}' not found.",
+ }
+
+ summary = _format_gmail_summary(detail)
+ content = (
+ f"# {summary['subject']}\n\n"
+ f"**From:** {summary['from']}\n"
+ f"**To:** {summary['to']}\n"
+ f"**Date:** {summary['date']}\n\n"
+ f"## Message Content\n\n"
+ f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
+ f"## Message Details\n\n"
+ f"- **Message ID:** {summary['message_id']}\n"
+ f"- **Thread ID:** {summary['thread_id']}\n"
+ )
+ return {
+ "status": "success",
+ "message_id": summary["message_id"] or message_id,
+ "content": content,
}
from app.agents.new_chat.tools.gmail.search_emails import (
- _format_gmail_summary,
+ _build_credentials,
)
- from app.services.composio_service import ComposioService
- service = ComposioService()
- detail, error = await service.get_gmail_message_detail(
- connected_account_id=cca_id,
- entity_id=f"surfsense_{user_id}",
- message_id=message_id,
+ creds = _build_credentials(connector)
+
+ from app.connectors.google_gmail_connector import GoogleGmailConnector
+
+ gmail = GoogleGmailConnector(
+ credentials=creds,
+ session=db_session,
+ user_id=user_id,
+ connector_id=connector.id,
)
+
+ detail, error = await gmail.get_message_details(message_id)
if error:
+ if (
+ "re-authenticate" in error.lower()
+ or "authentication failed" in error.lower()
+ ):
+ return {
+ "status": "auth_error",
+ "message": error,
+ "connector_type": "gmail",
+ }
return {"status": "error", "message": error}
+
if not detail:
return {
"status": "not_found",
"message": f"Email with ID '{message_id}' not found.",
}
- summary = _format_gmail_summary(detail)
- content = (
- f"# {summary['subject']}\n\n"
- f"**From:** {summary['from']}\n"
- f"**To:** {summary['to']}\n"
- f"**Date:** {summary['date']}\n\n"
- f"## Message Content\n\n"
- f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
- f"## Message Details\n\n"
- f"- **Message ID:** {summary['message_id']}\n"
- f"- **Thread ID:** {summary['thread_id']}\n"
- )
+ content = gmail.format_message_to_markdown(detail)
+
return {
"status": "success",
- "message_id": summary["message_id"] or message_id,
+ "message_id": message_id,
"content": content,
}
- from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
-
- creds = _build_credentials(connector)
-
- from app.connectors.google_gmail_connector import GoogleGmailConnector
-
- gmail = GoogleGmailConnector(
- credentials=creds,
- session=db_session,
- user_id=user_id,
- connector_id=connector.id,
- )
-
- detail, error = await gmail.get_message_details(message_id)
- if error:
- if (
- "re-authenticate" in error.lower()
- or "authentication failed" in error.lower()
- ):
- return {
- "status": "auth_error",
- "message": error,
- "connector_type": "gmail",
- }
- return {"status": "error", "message": error}
-
- if not detail:
- return {
- "status": "not_found",
- "message": f"Email with ID '{message_id}' not found.",
- }
-
- content = gmail.format_message_to_markdown(detail)
-
- return {"status": "success", "message_id": message_id, "content": content}
-
except Exception as e:
from langgraph.errors import GraphInterrupt
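
Both the Composio and native branches above normalize connector failures the same way: error text that mentions re-authentication becomes an auth_error result tagged with connector_type gmail, anything else a plain error. A small pure-function sketch of that mapping (classify_gmail_error is a hypothetical helper, not part of the patch):

def classify_gmail_error(error: str | None) -> dict | None:
    """Hypothetical helper mirroring the status mapping used above."""
    if not error:
        return None  # no error: the caller proceeds with the fetched detail
    lowered = error.lower()
    if "re-authenticate" in lowered or "authentication failed" in lowered:
        # Surfaced as auth_error so the UI can prompt for re-authentication.
        return {"status": "auth_error", "message": error, "connector_type": "gmail"}
    return {"status": "error", "message": error}

print(classify_gmail_error("Authentication failed: token revoked")["status"])  # auth_error
print(classify_gmail_error("Rate limit exceeded")["status"])                   # error
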
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
index 59886159a..3ce154c53 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
@@ -6,7 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@@ -124,6 +124,23 @@ def create_search_gmail_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the search_gmail tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured search_gmail tool
+ """
+    del db_session  # unused: each call opens its own session (see docstring)
+
@tool
async def search_gmail(
query: str,
@@ -142,91 +159,92 @@ def create_search_gmail_tool(
Dictionary with status and a list of email summaries including
message_id, subject, from, date, snippet.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."}
max_results = min(max_results, 20)
try:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
+ )
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
- }
-
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- ):
- return await _search_composio_gmail(
- connector, str(user_id), query, max_results
- )
-
- creds = _build_credentials(connector)
-
- from app.connectors.google_gmail_connector import GoogleGmailConnector
-
- gmail = GoogleGmailConnector(
- credentials=creds,
- session=db_session,
- user_id=user_id,
- connector_id=connector.id,
- )
-
- messages_list, error = await gmail.get_messages_list(
- max_results=max_results, query=query
- )
- if error:
- if (
- "re-authenticate" in error.lower()
- or "authentication failed" in error.lower()
- ):
+ connector = result.scalars().first()
+ if not connector:
return {
- "status": "auth_error",
- "message": error,
- "connector_type": "gmail",
+ "status": "error",
+ "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
- return {"status": "error", "message": error}
- if not messages_list:
- return {
- "status": "success",
- "emails": [],
- "total": 0,
- "message": "No emails found.",
- }
+ if (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ ):
+ return await _search_composio_gmail(
+ connector, str(user_id), query, max_results
+ )
- emails = []
- for msg in messages_list:
- detail, err = await gmail.get_message_details(msg["id"])
- if err:
- continue
- headers = {
- h["name"].lower(): h["value"]
- for h in detail.get("payload", {}).get("headers", [])
- }
- emails.append(
- {
- "message_id": detail.get("id"),
- "thread_id": detail.get("threadId"),
- "subject": headers.get("subject", "No Subject"),
- "from": headers.get("from", "Unknown"),
- "to": headers.get("to", ""),
- "date": headers.get("date", ""),
- "snippet": detail.get("snippet", ""),
- "labels": detail.get("labelIds", []),
- }
+ creds = _build_credentials(connector)
+
+ from app.connectors.google_gmail_connector import GoogleGmailConnector
+
+ gmail = GoogleGmailConnector(
+ credentials=creds,
+ session=db_session,
+ user_id=user_id,
+ connector_id=connector.id,
)
- return {"status": "success", "emails": emails, "total": len(emails)}
+ messages_list, error = await gmail.get_messages_list(
+ max_results=max_results, query=query
+ )
+ if error:
+ if (
+ "re-authenticate" in error.lower()
+ or "authentication failed" in error.lower()
+ ):
+ return {
+ "status": "auth_error",
+ "message": error,
+ "connector_type": "gmail",
+ }
+ return {"status": "error", "message": error}
+
+ if not messages_list:
+ return {
+ "status": "success",
+ "emails": [],
+ "total": 0,
+ "message": "No emails found.",
+ }
+
+ emails = []
+ for msg in messages_list:
+ detail, err = await gmail.get_message_details(msg["id"])
+ if err:
+ continue
+ headers = {
+ h["name"].lower(): h["value"]
+ for h in detail.get("payload", {}).get("headers", [])
+ }
+ emails.append(
+ {
+ "message_id": detail.get("id"),
+ "thread_id": detail.get("threadId"),
+ "subject": headers.get("subject", "No Subject"),
+ "from": headers.get("from", "Unknown"),
+ "to": headers.get("to", ""),
+ "date": headers.get("date", ""),
+ "snippet": detail.get("snippet", ""),
+ "labels": detail.get("labelIds", []),
+ }
+ )
+
+ return {"status": "success", "emails": emails, "total": len(emails)}
except Exception as e:
from langgraph.errors import GraphInterrupt
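
For the native connector, the loop above turns each Gmail message detail into a flat summary by lower-casing the header names and reading only the fields it needs. A standalone sketch of that per-message step, using a trimmed, made-up payload in the shape the Gmail API returns:

# Trimmed, made-up message detail in the shape returned by
# users().messages().get(...).
detail = {
    "id": "18f0abcd",
    "threadId": "18f0abcd",
    "snippet": "Quarterly numbers attached...",
    "labelIds": ["INBOX"],
    "payload": {
        "headers": [
            {"name": "Subject", "value": "Q3 report"},
            {"name": "From", "value": "alice@example.com"},
            {"name": "Date", "value": "Mon, 4 May 2026 09:00:00 -0700"},
        ]
    },
}

# Header names are case-insensitive, so they are lower-cased before lookup.
headers = {
    h["name"].lower(): h["value"]
    for h in detail.get("payload", {}).get("headers", [])
}
summary = {
    "message_id": detail.get("id"),
    "thread_id": detail.get("threadId"),
    "subject": headers.get("subject", "No Subject"),
    "from": headers.get("from", "Unknown"),
    "to": headers.get("to", ""),
    "date": headers.get("date", ""),
    "snippet": detail.get("snippet", ""),
    "labels": detail.get("labelIds", []),
}
print(summary["subject"], summary["from"])
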
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
index 79ff2d9c7..4d5aa3bcc 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_send_gmail_email_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the send_gmail_email tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured send_gmail_email tool
+ """
+    del db_session  # unused: each call opens its own session (see docstring)
+
@tool
async def send_gmail_email(
to: str,
@@ -58,268 +76,277 @@ def create_send_gmail_email_tool(
"""
logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
- metadata_service = GmailToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
+ async with async_session_maker() as db_session:
+ metadata_service = GmailToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
+ )
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {"status": "error", "message": context["error"]}
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- logger.warning("All Gmail accounts have expired authentication")
- return {
- "status": "auth_error",
- "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "gmail",
- }
-
- logger.info(
- f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
- )
- result = request_approval(
- action_type="gmail_email_send",
- tool_name="send_gmail_email",
- params={
- "to": to,
- "subject": subject,
- "body": body,
- "cc": cc,
- "bcc": bcc,
- "connector_id": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
- }
-
- final_to = result.params.get("to", to)
- final_subject = result.params.get("subject", subject)
- final_body = result.params.get("body", body)
- final_cc = result.params.get("cc", cc)
- final_bcc = result.params.get("bcc", bcc)
- final_connector_id = result.params.get("connector_id")
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _gmail_types = [
- SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
- ]
-
- if final_connector_id is not None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
+ if "error" in context:
+ logger.error(
+ f"Failed to fetch creation context: {context['error']}"
)
- )
- connector = result.scalars().first()
- if not connector:
+ return {"status": "error", "message": context["error"]}
+
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ logger.warning("All Gmail accounts have expired authentication")
return {
- "status": "error",
- "message": "Selected Gmail connector is invalid or has been disconnected.",
+ "status": "auth_error",
+ "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "gmail",
}
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
+
+ logger.info(
+ f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
+ )
+ result = request_approval(
+ action_type="gmail_email_send",
+ tool_name="send_gmail_email",
+ params={
+ "to": to,
+ "subject": subject,
+ "body": body,
+ "cc": cc,
+ "bcc": bcc,
+ "connector_id": None,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
+ }
+
+ final_to = result.params.get("to", to)
+ final_subject = result.params.get("subject", subject)
+ final_body = result.params.get("body", body)
+ final_cc = result.params.get("cc", cc)
+ final_bcc = result.params.get("bcc", bcc)
+ final_connector_id = result.params.get("connector_id")
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _gmail_types = [
+ SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
+ ]
+
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
+ )
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
- }
- actual_connector_id = connector.id
-
- logger.info(
- f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
- )
-
- is_composio_gmail = (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- )
- if is_composio_gmail:
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
- return {
- "status": "error",
- "message": "Composio connected account ID not found for this Gmail connector.",
- }
- else:
- from google.oauth2.credentials import Credentials
-
- from app.config import config
- from app.utils.oauth_security import TokenEncryption
-
- config_data = dict(connector.config)
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and config.SECRET_KEY:
- token_encryption = TokenEncryption(config.SECRET_KEY)
- if config_data.get("token"):
- config_data["token"] = token_encryption.decrypt_token(
- config_data["token"]
- )
- if config_data.get("refresh_token"):
- config_data["refresh_token"] = token_encryption.decrypt_token(
- config_data["refresh_token"]
- )
- if config_data.get("client_secret"):
- config_data["client_secret"] = token_encryption.decrypt_token(
- config_data["client_secret"]
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Gmail connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = connector.id
+ else:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
)
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
+ }
+ actual_connector_id = connector.id
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ logger.info(
+ f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
)
- message = MIMEText(final_body)
- message["to"] = final_to
- message["subject"] = final_subject
- if final_cc:
- message["cc"] = final_cc
- if final_bcc:
- message["bcc"] = final_bcc
- raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
-
- try:
+ is_composio_gmail = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ )
if is_composio_gmail:
- from app.agents.new_chat.tools.gmail.composio_helpers import (
- execute_composio_gmail_tool,
- split_recipients,
- )
-
- sent, error = await execute_composio_gmail_tool(
- connector,
- user_id,
- "GMAIL_SEND_EMAIL",
- {
- "user_id": "me",
- "recipient_email": final_to,
- "subject": final_subject,
- "body": final_body,
- "cc": split_recipients(final_cc),
- "bcc": split_recipients(final_bcc),
- "is_html": False,
- },
- )
- if error:
- raise RuntimeError(error)
- if not isinstance(sent, dict):
- sent = {}
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Gmail connector.",
+ }
else:
- from googleapiclient.discovery import build
+ from google.oauth2.credentials import Credentials
- gmail_service = build("gmail", "v1", credentials=creds)
- sent = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .messages()
- .send(userId="me", body={"raw": raw})
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
- )
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
+ config_data = dict(connector.config)
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
)
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = (
+ token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = (
+ token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
+ message = MIMEText(final_body)
+ message["to"] = final_to
+ message["subject"] = final_subject
+ if final_cc:
+ message["cc"] = final_cc
+ if final_bcc:
+ message["bcc"] = final_bcc
+ raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
+
+ try:
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ split_recipients,
)
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
+
+ sent, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_SEND_EMAIL",
+ {
+ "user_id": "me",
+ "recipient_email": final_to,
+ "subject": final_subject,
+ "body": final_body,
+ "cc": split_recipients(final_cc),
+ "bcc": split_recipients(final_bcc),
+ "is_html": False,
+ },
+ )
+ if error:
+ raise RuntimeError(error)
+ if not isinstance(sent, dict):
+ sent = {}
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ sent = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .messages()
+ .send(userId="me", body={"raw": raw})
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
+ f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
- return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
- }
- raise
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
- logger.info(
- f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
- )
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
- kb_message_suffix = ""
- try:
- from app.services.gmail import GmailKBSyncService
-
- kb_service = GmailKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- message_id=sent.get("id", ""),
- thread_id=sent.get("threadId", ""),
- subject=final_subject,
- sender="me",
- date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
- body_text=final_body,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
+ logger.info(
+ f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
)
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
- else:
- kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after send failed: {kb_err}")
- kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
- return {
- "status": "success",
- "message_id": sent.get("id"),
- "thread_id": sent.get("threadId"),
- "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
- }
+ kb_message_suffix = ""
+ try:
+ from app.services.gmail import GmailKBSyncService
+
+ kb_service = GmailKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ message_id=sent.get("id", ""),
+ thread_id=sent.get("threadId", ""),
+ subject=final_subject,
+ sender="me",
+ date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+ body_text=final_body,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after send failed: {kb_err}")
+ kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "message_id": sent.get("id"),
+ "thread_id": sent.get("threadId"),
+ "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
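
The draft and send paths build the outgoing mail identically: compose an RFC 2822 message with MIMEText, then base64url-encode it into the raw field the Gmail API expects. A standalone sketch of just that step (the addresses, subject, and body are placeholder values):

import base64
from email.mime.text import MIMEText

# Placeholder values; the patch fills these from the approved params.
final_to, final_subject, final_body = "bob@example.com", "Q3 report", "Numbers attached."
final_cc, final_bcc = "carol@example.com", ""

message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
    message["cc"] = final_cc
if final_bcc:
    message["bcc"] = final_bcc

# Gmail's REST endpoints take the whole RFC 2822 message, base64url-encoded,
# in a "raw" field rather than as structured fields.
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
send_body = {"raw": raw}                  # users().messages().send(userId="me", body=...)
draft_body = {"message": {"raw": raw}}    # users().drafts().create(userId="me", body=...)
print(send_body["raw"][:32], "...")
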
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
index 4e710dc72..95f5b4e6c 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
@@ -7,6 +7,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@@ -17,6 +18,23 @@ def create_trash_gmail_email_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the trash_gmail_email tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured trash_gmail_email tool
+ """
+    del db_session  # unused: each call opens its own session (see docstring)
+
@tool
async def trash_gmail_email(
email_subject_or_id: str,
@@ -55,254 +73,261 @@ def create_trash_gmail_email_tool(
f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
- metadata_service = GmailToolMetadataService(db_session)
- context = await metadata_service.get_trash_context(
- search_space_id, user_id, email_subject_or_id
- )
-
- if "error" in context:
- error_msg = context["error"]
- if "not found" in error_msg.lower():
- logger.warning(f"Email not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- logger.error(f"Failed to fetch trash context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Gmail account %s has expired authentication",
- account.get("id"),
+ async with async_session_maker() as db_session:
+ metadata_service = GmailToolMetadataService(db_session)
+ context = await metadata_service.get_trash_context(
+ search_space_id, user_id, email_subject_or_id
)
- return {
- "status": "auth_error",
- "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "gmail",
- }
- email = context["email"]
- message_id = email["message_id"]
- document_id = email.get("document_id")
- connector_id_from_context = context["account"]["id"]
+ if "error" in context:
+ error_msg = context["error"]
+ if "not found" in error_msg.lower():
+ logger.warning(f"Email not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ logger.error(f"Failed to fetch trash context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- if not message_id:
- return {
- "status": "error",
- "message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
- }
+ account = context.get("account", {})
+ if account.get("auth_expired"):
+ logger.warning(
+ "Gmail account %s has expired authentication",
+ account.get("id"),
+ )
+ return {
+ "status": "auth_error",
+ "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "gmail",
+ }
- logger.info(
- f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
- )
- result = request_approval(
- action_type="gmail_email_trash",
- tool_name="trash_gmail_email",
- params={
- "message_id": message_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
+ email = context["email"]
+ message_id = email["message_id"]
+ document_id = email.get("document_id")
+ connector_id_from_context = context["account"]["id"]
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
- }
-
- final_message_id = result.params.get("message_id", message_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this email.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _gmail_types = [
- SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
- ]
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
- )
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Gmail connector is invalid or has been disconnected.",
- }
-
- logger.info(
- f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
- )
-
- is_composio_gmail = (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- )
- if is_composio_gmail:
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
+ if not message_id:
return {
"status": "error",
- "message": "Composio connected account ID not found for this Gmail connector.",
+ "message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
}
- else:
- from google.oauth2.credentials import Credentials
- from app.config import config
- from app.utils.oauth_security import TokenEncryption
-
- config_data = dict(connector.config)
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and config.SECRET_KEY:
- token_encryption = TokenEncryption(config.SECRET_KEY)
- if config_data.get("token"):
- config_data["token"] = token_encryption.decrypt_token(
- config_data["token"]
- )
- if config_data.get("refresh_token"):
- config_data["refresh_token"] = token_encryption.decrypt_token(
- config_data["refresh_token"]
- )
- if config_data.get("client_secret"):
- config_data["client_secret"] = token_encryption.decrypt_token(
- config_data["client_secret"]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ logger.info(
+ f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
+ )
+ result = request_approval(
+ action_type="gmail_email_trash",
+ tool_name="trash_gmail_email",
+ params={
+ "message_id": message_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- try:
- if is_composio_gmail:
- from app.agents.new_chat.tools.gmail.composio_helpers import (
- execute_composio_gmail_tool,
- )
-
- _trashed, error = await execute_composio_gmail_tool(
- connector,
- user_id,
- "GMAIL_MOVE_TO_TRASH",
- {"user_id": "me", "message_id": final_message_id},
- )
- if error:
- raise RuntimeError(error)
- else:
- from googleapiclient.discovery import build
-
- gmail_service = build("gmail", "v1", credentials=creds)
- await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .messages()
- .trash(userId="me", id=final_message_id)
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {connector.id}: {api_err}"
- )
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- if not connector.config.get("auth_expired"):
- connector.config = {
- **connector.config,
- "auth_expired": True,
- }
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- connector.id,
- exc_info=True,
- )
+ if result.rejected:
return {
- "status": "insufficient_permissions",
- "connector_id": connector.id,
- "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "rejected",
+ "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
}
- raise
- logger.info(f"Gmail email trashed: message_id={final_message_id}")
-
- trash_result: dict[str, Any] = {
- "status": "success",
- "message_id": final_message_id,
- "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
- }
-
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- from app.db import Document
-
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- trash_result["warning"] = (
- f"Email trashed, but failed to remove from knowledge base: {e!s}"
- )
-
- trash_result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- trash_result["message"] = (
- f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ final_message_id = result.params.get("message_id", message_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
)
- return trash_result
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this email.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _gmail_types = [
+ SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
+ ]
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Gmail connector is invalid or has been disconnected.",
+ }
+
+ logger.info(
+ f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
+ )
+
+ is_composio_gmail = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ )
+ if is_composio_gmail:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Gmail connector.",
+ }
+ else:
+ from google.oauth2.credentials import Credentials
+
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
+
+ config_data = dict(connector.config)
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = (
+ token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = (
+ token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+ )
+
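+                        # datetime.fromisoformat() rejects a trailing "Z" on older Python versions; strip it before parsing.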
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
+ try:
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ )
+
+ _trashed, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_MOVE_TO_TRASH",
+ {"user_id": "me", "message_id": final_message_id},
+ )
+ if error:
+ raise RuntimeError(error)
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
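+                            # The Gmail client is synchronous; run the blocking call in a worker thread to keep the event loop free.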
+ await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .messages()
+ .trash(userId="me", id=final_message_id)
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
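+                        # A 403 usually means the OAuth grant lacks the required scope; flag the connector so the UI prompts re-authentication.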
+ logger.warning(
+ f"Insufficient permissions for connector {connector.id}: {api_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ if not connector.config.get("auth_expired"):
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ connector.id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": connector.id,
+ "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(f"Gmail email trashed: message_id={final_message_id}")
+
+ trash_result: dict[str, Any] = {
+ "status": "success",
+ "message_id": final_message_id,
+ "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
+ }
+
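+                # Optionally remove the indexed copy as well; the trash operation has already succeeded at this point.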
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ trash_result["warning"] = (
+ f"Email trashed, but failed to remove from knowledge base: {e!s}"
+ )
+
+ trash_result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ trash_result["message"] = (
+ f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py
index 50956f03a..129b7defb 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_update_gmail_draft_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the update_gmail_draft tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured update_gmail_draft tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def update_gmail_draft(
draft_subject_or_id: str,
@@ -76,324 +94,329 @@ def create_update_gmail_draft_tool(
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
- metadata_service = GmailToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, draft_subject_or_id
- )
-
- if "error" in context:
- error_msg = context["error"]
- if "not found" in error_msg.lower():
- logger.warning(f"Draft not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- logger.error(f"Failed to fetch update context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Gmail account %s has expired authentication",
- account.get("id"),
- )
- return {
- "status": "auth_error",
- "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "gmail",
- }
-
- email = context["email"]
- message_id = email["message_id"]
- document_id = email.get("document_id")
- connector_id_from_context = account["id"]
- draft_id_from_context = context.get("draft_id")
-
- original_subject = email.get("subject", draft_subject_or_id)
- final_subject_default = subject if subject else original_subject
- final_to_default = to if to else ""
-
- logger.info(
- f"Requesting approval for updating Gmail draft: '{original_subject}' "
- f"(message_id={message_id}, draft_id={draft_id_from_context})"
- )
- result = request_approval(
- action_type="gmail_draft_update",
- tool_name="update_gmail_draft",
- params={
- "message_id": message_id,
- "draft_id": draft_id_from_context,
- "to": final_to_default,
- "subject": final_subject_default,
- "body": body,
- "cc": cc,
- "bcc": bcc,
- "connector_id": connector_id_from_context,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
- }
-
- final_to = result.params.get("to", final_to_default)
- final_subject = result.params.get("subject", final_subject_default)
- final_body = result.params.get("body", body)
- final_cc = result.params.get("cc", cc)
- final_bcc = result.params.get("bcc", bcc)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_draft_id = result.params.get("draft_id", draft_id_from_context)
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this draft.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _gmail_types = [
- SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
- ]
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_gmail_types),
- )
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Gmail connector is invalid or has been disconnected.",
- }
-
- logger.info(
- f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
- )
-
- is_composio_gmail = (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
- )
- if is_composio_gmail:
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
- return {
- "status": "error",
- "message": "Composio connected account ID not found for this Gmail connector.",
- }
- else:
- from google.oauth2.credentials import Credentials
-
- from app.config import config
- from app.utils.oauth_security import TokenEncryption
-
- config_data = dict(connector.config)
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and config.SECRET_KEY:
- token_encryption = TokenEncryption(config.SECRET_KEY)
- if config_data.get("token"):
- config_data["token"] = token_encryption.decrypt_token(
- config_data["token"]
- )
- if config_data.get("refresh_token"):
- config_data["refresh_token"] = token_encryption.decrypt_token(
- config_data["refresh_token"]
- )
- if config_data.get("client_secret"):
- config_data["client_secret"] = token_encryption.decrypt_token(
- config_data["client_secret"]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ async with async_session_maker() as db_session:
+ metadata_service = GmailToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, draft_subject_or_id
)
- # Resolve draft_id if not already available
- if not final_draft_id:
- logger.info(
- f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
- )
- if is_composio_gmail:
- final_draft_id = await _find_composio_draft_id_by_message(
- connector, user_id, message_id
- )
- else:
- from googleapiclient.discovery import build
+ if "error" in context:
+ error_msg = context["error"]
+ if "not found" in error_msg.lower():
+ logger.warning(f"Draft not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ logger.error(f"Failed to fetch update context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- gmail_service = build("gmail", "v1", credentials=creds)
- final_draft_id = await _find_draft_id_by_message(
- gmail_service, message_id
- )
-
- if not final_draft_id:
- return {
- "status": "error",
- "message": (
- "Could not find this draft in Gmail. "
- "It may have already been sent or deleted."
- ),
- }
-
- message = MIMEText(final_body)
- if final_to:
- message["to"] = final_to
- message["subject"] = final_subject
- if final_cc:
- message["cc"] = final_cc
- if final_bcc:
- message["bcc"] = final_bcc
- raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
-
- try:
- if is_composio_gmail:
- from app.agents.new_chat.tools.gmail.composio_helpers import (
- execute_composio_gmail_tool,
- split_recipients,
- )
-
- updated, error = await execute_composio_gmail_tool(
- connector,
- user_id,
- "GMAIL_UPDATE_DRAFT",
- {
- "user_id": "me",
- "draft_id": final_draft_id,
- "recipient_email": final_to,
- "subject": final_subject,
- "body": final_body,
- "cc": split_recipients(final_cc),
- "bcc": split_recipients(final_bcc),
- "is_html": False,
- },
- )
- if error:
- raise RuntimeError(error)
- if not isinstance(updated, dict):
- updated = {}
- else:
- from googleapiclient.discovery import build
-
- gmail_service = build("gmail", "v1", credentials=creds)
- updated = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- gmail_service.users()
- .drafts()
- .update(
- userId="me",
- id=final_draft_id,
- body={"message": {"raw": raw}},
- )
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+ account = context.get("account", {})
+ if account.get("auth_expired"):
logger.warning(
- f"Insufficient permissions for connector {connector.id}: {api_err}"
+ "Gmail account %s has expired authentication",
+ account.get("id"),
)
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- if not connector.config.get("auth_expired"):
- connector.config = {
- **connector.config,
- "auth_expired": True,
- }
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- connector.id,
- exc_info=True,
- )
return {
- "status": "insufficient_permissions",
- "connector_id": connector.id,
- "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "auth_error",
+ "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "gmail",
}
- if isinstance(api_err, HttpError) and api_err.resp.status == 404:
+
+ email = context["email"]
+ message_id = email["message_id"]
+ document_id = email.get("document_id")
+ connector_id_from_context = account["id"]
+ draft_id_from_context = context.get("draft_id")
+
+ original_subject = email.get("subject", draft_subject_or_id)
+ final_subject_default = subject if subject else original_subject
+ final_to_default = to if to else ""
+
+ logger.info(
+ f"Requesting approval for updating Gmail draft: '{original_subject}' "
+ f"(message_id={message_id}, draft_id={draft_id_from_context})"
+ )
+ result = request_approval(
+ action_type="gmail_draft_update",
+ tool_name="update_gmail_draft",
+ params={
+ "message_id": message_id,
+ "draft_id": draft_id_from_context,
+ "to": final_to_default,
+ "subject": final_subject_default,
+ "body": body,
+ "cc": cc,
+ "bcc": bcc,
+ "connector_id": connector_id_from_context,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
+ }
+
+ final_to = result.params.get("to", final_to_default)
+ final_subject = result.params.get("subject", final_subject_default)
+ final_body = result.params.get("body", body)
+ final_cc = result.params.get("cc", cc)
+ final_bcc = result.params.get("bcc", bcc)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_draft_id = result.params.get("draft_id", draft_id_from_context)
+
+ if not final_connector_id:
return {
"status": "error",
- "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
+ "message": "No connector found for this draft.",
}
- raise
- logger.info(f"Gmail draft updated: id={updated.get('id')}")
+ from sqlalchemy.future import select
- kb_message_suffix = ""
- if document_id:
- try:
- from sqlalchemy.future import select as sa_select
- from sqlalchemy.orm.attributes import flag_modified
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
- from app.db import Document
+ _gmail_types = [
+ SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
+ ]
- doc_result = await db_session.execute(
- sa_select(Document).filter(Document.id == document_id)
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_gmail_types),
)
- document = doc_result.scalars().first()
- if document:
- document.source_markdown = final_body
- document.title = final_subject
- meta = dict(document.document_metadata or {})
- meta["subject"] = final_subject
- meta["draft_id"] = updated.get("id", final_draft_id)
- updated_msg = updated.get("message", {})
- if updated_msg.get("id"):
- meta["message_id"] = updated_msg["id"]
- document.document_metadata = meta
- flag_modified(document, "document_metadata")
- await db_session.commit()
- kb_message_suffix = (
- " Your knowledge base has also been updated."
- )
- logger.info(
- f"KB document {document_id} updated for draft {final_draft_id}"
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Gmail connector is invalid or has been disconnected.",
+ }
+
+ logger.info(
+ f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
+ )
+
+ is_composio_gmail = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+ )
+ if is_composio_gmail:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Gmail connector.",
+ }
+ else:
+ from google.oauth2.credentials import Credentials
+
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
+
+ config_data = dict(connector.config)
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = (
+ token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = (
+ token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
+ # Resolve draft_id if not already available
+ if not final_draft_id:
+ logger.info(
+ f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
+ )
+ if is_composio_gmail:
+ final_draft_id = await _find_composio_draft_id_by_message(
+ connector, user_id, message_id
)
else:
- kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB update after draft edit failed: {kb_err}")
- await db_session.rollback()
- kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
+ from googleapiclient.discovery import build
- return {
- "status": "success",
- "draft_id": updated.get("id"),
- "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
- }
+ gmail_service = build("gmail", "v1", credentials=creds)
+ final_draft_id = await _find_draft_id_by_message(
+ gmail_service, message_id
+ )
+
+ if not final_draft_id:
+ return {
+ "status": "error",
+ "message": (
+ "Could not find this draft in Gmail. "
+ "It may have already been sent or deleted."
+ ),
+ }
+
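+                # Rebuild the draft as a raw RFC 2822 message; the Gmail API expects it base64url-encoded.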
+ message = MIMEText(final_body)
+ if final_to:
+ message["to"] = final_to
+ message["subject"] = final_subject
+ if final_cc:
+ message["cc"] = final_cc
+ if final_bcc:
+ message["bcc"] = final_bcc
+ raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
+
+ try:
+ if is_composio_gmail:
+ from app.agents.new_chat.tools.gmail.composio_helpers import (
+ execute_composio_gmail_tool,
+ split_recipients,
+ )
+
+ updated, error = await execute_composio_gmail_tool(
+ connector,
+ user_id,
+ "GMAIL_UPDATE_DRAFT",
+ {
+ "user_id": "me",
+ "draft_id": final_draft_id,
+ "recipient_email": final_to,
+ "subject": final_subject,
+ "body": final_body,
+ "cc": split_recipients(final_cc),
+ "bcc": split_recipients(final_bcc),
+ "is_html": False,
+ },
+ )
+ if error:
+ raise RuntimeError(error)
+ if not isinstance(updated, dict):
+ updated = {}
+ else:
+ from googleapiclient.discovery import build
+
+ gmail_service = build("gmail", "v1", credentials=creds)
+ updated = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ gmail_service.users()
+ .drafts()
+ .update(
+ userId="me",
+ id=final_draft_id,
+ body={"message": {"raw": raw}},
+ )
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {connector.id}: {api_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ if not connector.config.get("auth_expired"):
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ connector.id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": connector.id,
+ "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ if isinstance(api_err, HttpError) and api_err.resp.status == 404:
+ return {
+ "status": "error",
+ "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
+ }
+ raise
+
+ logger.info(f"Gmail draft updated: id={updated.get('id')}")
+
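+            # Best-effort KB sync: a failure here only softens the success message, it never fails the draft update.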
+ kb_message_suffix = ""
+ if document_id:
+ try:
+ from sqlalchemy.future import select as sa_select
+ from sqlalchemy.orm.attributes import flag_modified
+
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ sa_select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ document.source_markdown = final_body
+ document.title = final_subject
+ meta = dict(document.document_metadata or {})
+ meta["subject"] = final_subject
+ meta["draft_id"] = updated.get("id", final_draft_id)
+ updated_msg = updated.get("message", {})
+ if updated_msg.get("id"):
+ meta["message_id"] = updated_msg["id"]
+ document.document_metadata = meta
+ flag_modified(document, "document_metadata")
+ await db_session.commit()
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ logger.info(
+ f"KB document {document_id} updated for draft {final_draft_id}"
+ )
+ else:
+ kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB update after draft edit failed: {kb_err}")
+ await db_session.rollback()
+ kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "draft_id": updated.get("id"),
+ "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
index 0a4720f6f..dec92cc8b 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_create_calendar_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_calendar_event tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured create_calendar_event tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_calendar_event(
summary: str,
@@ -60,284 +78,294 @@ def create_create_calendar_event_tool(
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
- metadata_service = GoogleCalendarToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
-
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {"status": "error", "message": context["error"]}
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- logger.warning(
- "All Google Calendar accounts have expired authentication"
+ async with async_session_maker() as db_session:
+ metadata_service = GoogleCalendarToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
)
- return {
- "status": "auth_error",
- "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "google_calendar",
- }
- logger.info(
- f"Requesting approval for creating calendar event: summary='{summary}'"
- )
- result = request_approval(
- action_type="google_calendar_event_creation",
- tool_name="create_calendar_event",
- params={
- "summary": summary,
- "start_datetime": start_datetime,
- "end_datetime": end_datetime,
- "description": description,
- "location": location,
- "attendees": attendees,
- "timezone": context.get("timezone"),
- "connector_id": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
- }
-
- final_summary = result.params.get("summary", summary)
- final_start_datetime = result.params.get("start_datetime", start_datetime)
- final_end_datetime = result.params.get("end_datetime", end_datetime)
- final_description = result.params.get("description", description)
- final_location = result.params.get("location", location)
- final_attendees = result.params.get("attendees", attendees)
- final_connector_id = result.params.get("connector_id")
-
- if not final_summary or not final_summary.strip():
- return {"status": "error", "message": "Event summary cannot be empty."}
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _calendar_types = [
- SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
- ]
-
- if final_connector_id is not None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_calendar_types),
+ if "error" in context:
+ logger.error(
+ f"Failed to fetch creation context: {context['error']}"
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Google Calendar connector is invalid or has been disconnected.",
- }
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_calendar_types),
+ return {"status": "error", "message": context["error"]}
+
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ logger.warning(
+ "All Google Calendar accounts have expired authentication"
)
+ return {
+ "status": "auth_error",
+ "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "google_calendar",
+ }
+
+ logger.info(
+ f"Requesting approval for creating calendar event: summary='{summary}'"
)
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
- }
- actual_connector_id = connector.id
-
- logger.info(
- f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
- )
-
- is_composio_calendar = (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- )
- if is_composio_calendar:
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
- return {
- "status": "error",
- "message": "Composio connected account ID not found for this connector.",
- }
- else:
- config_data = dict(connector.config)
-
- from app.config import config as app_config
- from app.utils.oauth_security import TokenEncryption
-
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and app_config.SECRET_KEY:
- token_encryption = TokenEncryption(app_config.SECRET_KEY)
- for key in ("token", "refresh_token", "client_secret"):
- if config_data.get(key):
- config_data[key] = token_encryption.decrypt_token(
- config_data[key]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ result = request_approval(
+ action_type="google_calendar_event_creation",
+ tool_name="create_calendar_event",
+ params={
+ "summary": summary,
+ "start_datetime": start_datetime,
+ "end_datetime": end_datetime,
+ "description": description,
+ "location": location,
+ "attendees": attendees,
+ "timezone": context.get("timezone"),
+ "connector_id": None,
+ },
+ context=context,
)
- tz = context.get("timezone", "UTC")
- event_body: dict[str, Any] = {
- "summary": final_summary,
- "start": {"dateTime": final_start_datetime, "timeZone": tz},
- "end": {"dateTime": final_end_datetime, "timeZone": tz},
- }
- if final_description:
- event_body["description"] = final_description
- if final_location:
- event_body["location"] = final_location
- if final_attendees:
- event_body["attendees"] = [
- {"email": e.strip()} for e in final_attendees if e.strip()
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
+ }
+
+ final_summary = result.params.get("summary", summary)
+ final_start_datetime = result.params.get(
+ "start_datetime", start_datetime
+ )
+ final_end_datetime = result.params.get("end_datetime", end_datetime)
+ final_description = result.params.get("description", description)
+ final_location = result.params.get("location", location)
+ final_attendees = result.params.get("attendees", attendees)
+ final_connector_id = result.params.get("connector_id")
+
+ if not final_summary or not final_summary.strip():
+ return {
+ "status": "error",
+ "message": "Event summary cannot be empty.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _calendar_types = [
+ SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
- try:
- if is_composio_calendar:
- from app.services.composio_service import ComposioService
-
- composio_params = {
- "calendar_id": "primary",
- "summary": final_summary,
- "start_datetime": final_start_datetime,
- "end_datetime": final_end_datetime,
- "timezone": tz,
- "attendees": final_attendees or [],
- }
- if final_description:
- composio_params["description"] = final_description
- if final_location:
- composio_params["location"] = final_location
-
- composio_result = await ComposioService().execute_tool(
- connected_account_id=cca_id,
- tool_name="GOOGLECALENDAR_CREATE_EVENT",
- params=composio_params,
- entity_id=f"surfsense_{user_id}",
- )
- if not composio_result.get("success"):
- raise RuntimeError(
- composio_result.get(
- "error", "Unknown Composio Calendar error"
- )
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_calendar_types),
)
- created = composio_result.get("data", {})
- if isinstance(created, dict):
- created = created.get("data", created)
- if isinstance(created, dict):
- created = created.get("response_data", created)
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Google Calendar connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = connector.id
else:
- service = await asyncio.get_event_loop().run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
- created = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- service.events()
- .insert(calendarId="primary", body=event_body)
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
- )
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
- )
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_calendar_types),
)
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
- )
- return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
- }
- raise
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
+ }
+ actual_connector_id = connector.id
- logger.info(
- f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
- )
-
- kb_message_suffix = ""
- try:
- from app.services.google_calendar import GoogleCalendarKBSyncService
-
- kb_service = GoogleCalendarKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- event_id=created.get("id"),
- event_summary=final_summary,
- calendar_id="primary",
- start_time=final_start_datetime,
- end_time=final_end_datetime,
- location=final_location,
- html_link=created.get("htmlLink"),
- description=final_description,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
+ logger.info(
+ f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
)
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
- else:
- kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
- return {
- "status": "success",
- "event_id": created.get("id"),
- "html_link": created.get("htmlLink"),
- "message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
- }
+ is_composio_calendar = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+ )
+ if is_composio_calendar:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this connector.",
+ }
+ else:
+ config_data = dict(connector.config)
+
+ from app.config import config as app_config
+ from app.utils.oauth_security import TokenEncryption
+
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and app_config.SECRET_KEY:
+ token_encryption = TokenEncryption(app_config.SECRET_KEY)
+ for key in ("token", "refresh_token", "client_secret"):
+ if config_data.get(key):
+ config_data[key] = token_encryption.decrypt_token(
+ config_data[key]
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
+ tz = context.get("timezone", "UTC")
+ event_body: dict[str, Any] = {
+ "summary": final_summary,
+ "start": {"dateTime": final_start_datetime, "timeZone": tz},
+ "end": {"dateTime": final_end_datetime, "timeZone": tz},
+ }
+ if final_description:
+ event_body["description"] = final_description
+ if final_location:
+ event_body["location"] = final_location
+ if final_attendees:
+ event_body["attendees"] = [
+ {"email": e.strip()} for e in final_attendees if e.strip()
+ ]
+
+ try:
+ if is_composio_calendar:
+ from app.services.composio_service import ComposioService
+
+ composio_params = {
+ "calendar_id": "primary",
+ "summary": final_summary,
+ "start_datetime": final_start_datetime,
+ "end_datetime": final_end_datetime,
+ "timezone": tz,
+ "attendees": final_attendees or [],
+ }
+ if final_description:
+ composio_params["description"] = final_description
+ if final_location:
+ composio_params["location"] = final_location
+
+ composio_result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLECALENDAR_CREATE_EVENT",
+ params=composio_params,
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not composio_result.get("success"):
+ raise RuntimeError(
+ composio_result.get(
+ "error", "Unknown Composio Calendar error"
+ )
+ )
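+                        # Composio responses may nest the created event under data / response_data; unwrap defensively.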
+ created = composio_result.get("data", {})
+ if isinstance(created, dict):
+ created = created.get("data", created)
+ if isinstance(created, dict):
+ created = created.get("response_data", created)
+ else:
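+                        # build() and events().insert() are blocking; run them off the event loop.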
+ service = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ created = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .insert(calendarId="primary", body=event_body)
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(
+ f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
+ )
+
+ kb_message_suffix = ""
+ try:
+ from app.services.google_calendar import GoogleCalendarKBSyncService
+
+ kb_service = GoogleCalendarKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ event_id=created.get("id"),
+ event_summary=final_summary,
+ calendar_id="primary",
+ start_time=final_start_datetime,
+ end_time=final_end_datetime,
+ location=final_location,
+ html_link=created.get("htmlLink"),
+ description=final_description,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "event_id": created.get("id"),
+ "html_link": created.get("htmlLink"),
+ "message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
index 53596ac0f..e7e891b08 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_delete_calendar_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the delete_calendar_event tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured delete_calendar_event tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def delete_calendar_event(
event_title_or_id: str,
@@ -54,252 +72,258 @@ def create_delete_calendar_event_tool(
f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
- metadata_service = GoogleCalendarToolMetadataService(db_session)
- context = await metadata_service.get_deletion_context(
- search_space_id, user_id, event_title_or_id
- )
-
- if "error" in context:
- error_msg = context["error"]
- if "not found" in error_msg.lower():
- logger.warning(f"Event not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- logger.error(f"Failed to fetch deletion context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Google Calendar account %s has expired authentication",
- account.get("id"),
+ async with async_session_maker() as db_session:
+ metadata_service = GoogleCalendarToolMetadataService(db_session)
+ context = await metadata_service.get_deletion_context(
+ search_space_id, user_id, event_title_or_id
)
- return {
- "status": "auth_error",
- "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "google_calendar",
- }
- event = context["event"]
- event_id = event["event_id"]
- document_id = event.get("document_id")
- connector_id_from_context = context["account"]["id"]
+ if "error" in context:
+ error_msg = context["error"]
+ if "not found" in error_msg.lower():
+ logger.warning(f"Event not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ logger.error(f"Failed to fetch deletion context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- if not event_id:
- return {
- "status": "error",
- "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
- }
+ account = context.get("account", {})
+ if account.get("auth_expired"):
+ logger.warning(
+ "Google Calendar account %s has expired authentication",
+ account.get("id"),
+ )
+ return {
+ "status": "auth_error",
+ "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "google_calendar",
+ }
- logger.info(
- f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
- )
- result = request_approval(
- action_type="google_calendar_event_deletion",
- tool_name="delete_calendar_event",
- params={
- "event_id": event_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
+ event = context["event"]
+ event_id = event["event_id"]
+ document_id = event.get("document_id")
+ connector_id_from_context = context["account"]["id"]
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
- }
-
- final_event_id = result.params.get("event_id", event_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this event.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _calendar_types = [
- SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
- ]
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_calendar_types),
- )
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Google Calendar connector is invalid or has been disconnected.",
- }
-
- actual_connector_id = connector.id
-
- logger.info(
- f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
- )
-
- is_composio_calendar = (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- )
- if is_composio_calendar:
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
+ if not event_id:
return {
"status": "error",
- "message": "Composio connected account ID not found for this connector.",
+ "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
}
- else:
- config_data = dict(connector.config)
- from app.config import config as app_config
- from app.utils.oauth_security import TokenEncryption
-
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and app_config.SECRET_KEY:
- token_encryption = TokenEncryption(app_config.SECRET_KEY)
- for key in ("token", "refresh_token", "client_secret"):
- if config_data.get(key):
- config_data[key] = token_encryption.decrypt_token(
- config_data[key]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ logger.info(
+ f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
+ )
+ result = request_approval(
+ action_type="google_calendar_event_deletion",
+ tool_name="delete_calendar_event",
+ params={
+ "event_id": event_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- try:
- if is_composio_calendar:
- from app.services.composio_service import ComposioService
-
- composio_result = await ComposioService().execute_tool(
- connected_account_id=cca_id,
- tool_name="GOOGLECALENDAR_DELETE_EVENT",
- params={"calendar_id": "primary", "event_id": final_event_id},
- entity_id=f"surfsense_{user_id}",
- )
- if not composio_result.get("success"):
- raise RuntimeError(
- composio_result.get(
- "error", "Unknown Composio Calendar error"
- )
- )
- else:
- service = await asyncio.get_event_loop().run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
- await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- service.events()
- .delete(calendarId="primary", eventId=final_event_id)
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
- )
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
- )
- )
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
- )
+ if result.rejected:
return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "rejected",
+ "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
}
- raise
- logger.info(f"Calendar event deleted: event_id={final_event_id}")
-
- delete_result: dict[str, Any] = {
- "status": "success",
- "event_id": final_event_id,
- "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
- }
-
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- from app.db import Document
-
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- delete_result["warning"] = (
- f"Event deleted, but failed to remove from knowledge base: {e!s}"
- )
-
- delete_result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- delete_result["message"] = (
- f"{delete_result.get('message', '')} (also removed from knowledge base)"
+ final_event_id = result.params.get("event_id", event_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
)
- return delete_result
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this event.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _calendar_types = [
+ SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
+ ]
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_calendar_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Google Calendar connector is invalid or has been disconnected.",
+ }
+
+ actual_connector_id = connector.id
+
+ logger.info(
+ f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
+ )
+
+ is_composio_calendar = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+ )
+ if is_composio_calendar:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this connector.",
+ }
+ else:
+ config_data = dict(connector.config)
+
+ from app.config import config as app_config
+ from app.utils.oauth_security import TokenEncryption
+
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and app_config.SECRET_KEY:
+ token_encryption = TokenEncryption(app_config.SECRET_KEY)
+ for key in ("token", "refresh_token", "client_secret"):
+ if config_data.get(key):
+ config_data[key] = token_encryption.decrypt_token(
+ config_data[key]
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
+ )
+
+ try:
+ if is_composio_calendar:
+ from app.services.composio_service import ComposioService
+
+ composio_result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLECALENDAR_DELETE_EVENT",
+ params={
+ "calendar_id": "primary",
+ "event_id": final_event_id,
+ },
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not composio_result.get("success"):
+ raise RuntimeError(
+ composio_result.get(
+ "error", "Unknown Composio Calendar error"
+ )
+ )
+ else:
+ service = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .delete(calendarId="primary", eventId=final_event_id)
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(f"Calendar event deleted: event_id={final_event_id}")
+
+ delete_result: dict[str, Any] = {
+ "status": "success",
+ "event_id": final_event_id,
+ "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
+ }
+
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ delete_result["warning"] = (
+ f"Event deleted, but failed to remove from knowledge base: {e!s}"
+ )
+
+ delete_result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ delete_result["message"] = (
+ f"{delete_result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return delete_result
except Exception as e:
from langgraph.errors import GraphInterrupt
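
The 403 handling above repeats a pattern used by every write tool in this patch: catch the HttpError, persist auth_expired on the connector's JSON config so later calls short-circuit with an auth error, and return insufficient_permissions to the caller. A minimal sketch of that step as a standalone coroutine, assuming the models shown in the diff; the name mark_connector_auth_expired is illustrative and not part of the patch:

    # Illustrative sketch only; mark_connector_auth_expired is not defined in this patch.
    from sqlalchemy.ext.asyncio import AsyncSession
    from sqlalchemy.future import select
    from sqlalchemy.orm.attributes import flag_modified

    from app.db import SearchSourceConnector


    async def mark_connector_auth_expired(db_session: AsyncSession, connector_id: int) -> None:
        """Persist auth_expired=True on a connector's JSON config (best effort)."""
        result = await db_session.execute(
            select(SearchSourceConnector).where(SearchSourceConnector.id == connector_id)
        )
        connector = result.scalar_one_or_none()
        if connector and not connector.config.get("auth_expired"):
            # Re-assign the dict and flag it so SQLAlchemy detects the JSON mutation.
            connector.config = {**connector.config, "auth_expired": True}
            flag_modified(connector, "config")
            await db_session.commit()
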
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
index b5194d15f..e5f18f675 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
@@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@@ -50,6 +50,23 @@ def create_search_calendar_events_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the search_calendar_events tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured search_calendar_events tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def search_calendar_events(
start_date: str,
@@ -67,7 +84,7 @@ def create_search_calendar_events_tool(
Dictionary with status and a list of events including
event_id, summary, start, end, location, attendees.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Calendar tool not properly configured.",
@@ -76,84 +93,85 @@ def create_search_calendar_events_tool(
max_results = min(max_results, 50)
try:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
+ )
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
- }
-
- if (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- ):
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
+ connector = result.scalars().first()
+ if not connector:
return {
"status": "error",
- "message": "Composio connected account ID not found for this connector.",
+ "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
}
- from app.services.composio_service import ComposioService
-
- events_raw, error = await ComposioService().get_calendar_events(
- connected_account_id=cca_id,
- entity_id=f"surfsense_{user_id}",
- time_min=_to_calendar_boundary(start_date, is_end=False),
- time_max=_to_calendar_boundary(end_date, is_end=True),
- max_results=max_results,
- )
- if not events_raw and not error:
- error = "No events found in the specified date range."
- else:
- creds = _build_credentials(connector)
-
- from app.connectors.google_calendar_connector import (
- GoogleCalendarConnector,
- )
-
- cal = GoogleCalendarConnector(
- credentials=creds,
- session=db_session,
- user_id=user_id,
- connector_id=connector.id,
- )
-
- events_raw, error = await cal.get_all_primary_calendar_events(
- start_date=start_date,
- end_date=end_date,
- max_results=max_results,
- )
-
- if error:
if (
- "re-authenticate" in error.lower()
- or "authentication failed" in error.lower()
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
- return {
- "status": "auth_error",
- "message": error,
- "connector_type": "google_calendar",
- }
- if "no events found" in error.lower():
- return {
- "status": "success",
- "events": [],
- "total": 0,
- "message": error,
- }
- return {"status": "error", "message": error}
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this connector.",
+ }
- events = _format_calendar_events(events_raw)
+ from app.services.composio_service import ComposioService
- return {"status": "success", "events": events, "total": len(events)}
+ events_raw, error = await ComposioService().get_calendar_events(
+ connected_account_id=cca_id,
+ entity_id=f"surfsense_{user_id}",
+ time_min=_to_calendar_boundary(start_date, is_end=False),
+ time_max=_to_calendar_boundary(end_date, is_end=True),
+ max_results=max_results,
+ )
+ if not events_raw and not error:
+ error = "No events found in the specified date range."
+ else:
+ creds = _build_credentials(connector)
+
+ from app.connectors.google_calendar_connector import (
+ GoogleCalendarConnector,
+ )
+
+ cal = GoogleCalendarConnector(
+ credentials=creds,
+ session=db_session,
+ user_id=user_id,
+ connector_id=connector.id,
+ )
+
+ events_raw, error = await cal.get_all_primary_calendar_events(
+ start_date=start_date,
+ end_date=end_date,
+ max_results=max_results,
+ )
+
+ if error:
+ if (
+ "re-authenticate" in error.lower()
+ or "authentication failed" in error.lower()
+ ):
+ return {
+ "status": "auth_error",
+ "message": error,
+ "connector_type": "google_calendar",
+ }
+ if "no events found" in error.lower():
+ return {
+ "status": "success",
+ "events": [],
+ "total": 0,
+ "message": error,
+ }
+ return {"status": "error", "message": error}
+
+ events = _format_calendar_events(events_raw)
+
+ return {"status": "success", "events": events, "total": len(events)}
except Exception as e:
from langgraph.errors import GraphInterrupt
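
The docstring added here spells out the session strategy all of these factories now follow: the injected request-scoped session is discarded and each invocation opens its own session, so the compiled-agent cache can safely reuse the closure across requests. A minimal sketch of that shape, assuming only async_session_maker and SearchSourceConnector from app.db; the tool body is a placeholder, not part of the patch:

    # Illustrative sketch of the per-call session pattern; the tool body is a placeholder.
    from langchain_core.tools import tool
    from sqlalchemy.ext.asyncio import AsyncSession
    from sqlalchemy.future import select

    from app.db import SearchSourceConnector, async_session_maker


    def create_example_tool(
        db_session: AsyncSession | None = None,
        search_space_id: int | None = None,
        user_id: str | None = None,
    ):
        del db_session  # unused: each call opens its own short-lived session

        @tool
        async def example_tool(query: str) -> dict:
            """Placeholder tool that looks up connectors with a per-call session."""
            if search_space_id is None or user_id is None:
                return {"status": "error", "message": "Tool not properly configured."}
            async with async_session_maker() as session:
                result = await session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                    )
                )
                return {"status": "success", "total": len(result.scalars().all())}

        return example_tool
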
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
index 1dba36c20..b8561fee6 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
@@ -33,6 +34,23 @@ def create_update_calendar_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the update_calendar_event tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured update_calendar_event tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def update_calendar_event(
event_title_or_id: str,
@@ -74,312 +92,317 @@ def create_update_calendar_event_tool(
"""
logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
- metadata_service = GoogleCalendarToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, event_title_or_id
- )
-
- if "error" in context:
- error_msg = context["error"]
- if "not found" in error_msg.lower():
- logger.warning(f"Event not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- logger.error(f"Failed to fetch update context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- if context.get("auth_expired"):
- logger.warning("Google Calendar account has expired authentication")
- return {
- "status": "auth_error",
- "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "google_calendar",
- }
-
- event = context["event"]
- event_id = event["event_id"]
- document_id = event.get("document_id")
- connector_id_from_context = context["account"]["id"]
-
- if not event_id:
- return {
- "status": "error",
- "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
- }
-
- logger.info(
- f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
- )
- result = request_approval(
- action_type="google_calendar_event_update",
- tool_name="update_calendar_event",
- params={
- "event_id": event_id,
- "document_id": document_id,
- "connector_id": connector_id_from_context,
- "new_summary": new_summary,
- "new_start_datetime": new_start_datetime,
- "new_end_datetime": new_end_datetime,
- "new_description": new_description,
- "new_location": new_location,
- "new_attendees": new_attendees,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
- }
-
- final_event_id = result.params.get("event_id", event_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_new_summary = result.params.get("new_summary", new_summary)
- final_new_start_datetime = result.params.get(
- "new_start_datetime", new_start_datetime
- )
- final_new_end_datetime = result.params.get(
- "new_end_datetime", new_end_datetime
- )
- final_new_description = result.params.get(
- "new_description", new_description
- )
- final_new_location = result.params.get("new_location", new_location)
- final_new_attendees = result.params.get("new_attendees", new_attendees)
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this event.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _calendar_types = [
- SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
- ]
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_calendar_types),
+ async with async_session_maker() as db_session:
+ metadata_service = GoogleCalendarToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, event_title_or_id
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Google Calendar connector is invalid or has been disconnected.",
- }
- actual_connector_id = connector.id
+ if "error" in context:
+ error_msg = context["error"]
+ if "not found" in error_msg.lower():
+ logger.warning(f"Event not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ logger.error(f"Failed to fetch update context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- logger.info(
- f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
- )
+ if context.get("auth_expired"):
+ logger.warning("Google Calendar account has expired authentication")
+ return {
+ "status": "auth_error",
+ "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "google_calendar",
+ }
- is_composio_calendar = (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
- )
- if is_composio_calendar:
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
+ event = context["event"]
+ event_id = event["event_id"]
+ document_id = event.get("document_id")
+ connector_id_from_context = context["account"]["id"]
+
+ if not event_id:
return {
"status": "error",
- "message": "Composio connected account ID not found for this connector.",
+ "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
}
- else:
- config_data = dict(connector.config)
- from app.config import config as app_config
- from app.utils.oauth_security import TokenEncryption
-
- token_encrypted = config_data.get("_token_encrypted", False)
- if token_encrypted and app_config.SECRET_KEY:
- token_encryption = TokenEncryption(app_config.SECRET_KEY)
- for key in ("token", "refresh_token", "client_secret"):
- if config_data.get(key):
- config_data[key] = token_encryption.decrypt_token(
- config_data[key]
- )
-
- exp = config_data.get("expiry", "")
- if exp:
- exp = exp.replace("Z", "")
-
- creds = Credentials(
- token=config_data.get("token"),
- refresh_token=config_data.get("refresh_token"),
- token_uri=config_data.get("token_uri"),
- client_id=config_data.get("client_id"),
- client_secret=config_data.get("client_secret"),
- scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp) if exp else None,
+ logger.info(
+ f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
+ )
+ result = request_approval(
+ action_type="google_calendar_event_update",
+ tool_name="update_calendar_event",
+ params={
+ "event_id": event_id,
+ "document_id": document_id,
+ "connector_id": connector_id_from_context,
+ "new_summary": new_summary,
+ "new_start_datetime": new_start_datetime,
+ "new_end_datetime": new_end_datetime,
+ "new_description": new_description,
+ "new_location": new_location,
+ "new_attendees": new_attendees,
+ },
+ context=context,
)
- update_body: dict[str, Any] = {}
- if final_new_summary is not None:
- update_body["summary"] = final_new_summary
- if final_new_start_datetime is not None:
- update_body["start"] = _build_time_body(
- final_new_start_datetime, context
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
+ }
+
+ final_event_id = result.params.get("event_id", event_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
)
- if final_new_end_datetime is not None:
- update_body["end"] = _build_time_body(final_new_end_datetime, context)
- if final_new_description is not None:
- update_body["description"] = final_new_description
- if final_new_location is not None:
- update_body["location"] = final_new_location
- if final_new_attendees is not None:
- update_body["attendees"] = [
- {"email": e.strip()} for e in final_new_attendees if e.strip()
+ final_new_summary = result.params.get("new_summary", new_summary)
+ final_new_start_datetime = result.params.get(
+ "new_start_datetime", new_start_datetime
+ )
+ final_new_end_datetime = result.params.get(
+ "new_end_datetime", new_end_datetime
+ )
+ final_new_description = result.params.get(
+ "new_description", new_description
+ )
+ final_new_location = result.params.get("new_location", new_location)
+ final_new_attendees = result.params.get("new_attendees", new_attendees)
+
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this event.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _calendar_types = [
+ SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
- if not update_body:
- return {
- "status": "error",
- "message": "No changes specified. Please provide at least one field to update.",
- }
-
- try:
- if is_composio_calendar:
- from app.services.composio_service import ComposioService
-
- composio_params: dict[str, Any] = {
- "calendar_id": "primary",
- "event_id": final_event_id,
- }
- if final_new_summary is not None:
- composio_params["summary"] = final_new_summary
- if final_new_start_datetime is not None:
- composio_params["start_time"] = final_new_start_datetime
- if final_new_end_datetime is not None:
- composio_params["end_time"] = final_new_end_datetime
- if final_new_description is not None:
- composio_params["description"] = final_new_description
- if final_new_location is not None:
- composio_params["location"] = final_new_location
- if final_new_attendees is not None:
- composio_params["attendees"] = [
- e.strip() for e in final_new_attendees if e.strip()
- ]
- if not _is_date_only(
- final_new_start_datetime or final_new_end_datetime or ""
- ):
- composio_params["timezone"] = context.get("timezone", "UTC")
-
- composio_result = await ComposioService().execute_tool(
- connected_account_id=cca_id,
- tool_name="GOOGLECALENDAR_PATCH_EVENT",
- params=composio_params,
- entity_id=f"surfsense_{user_id}",
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_calendar_types),
)
- if not composio_result.get("success"):
- raise RuntimeError(
- composio_result.get(
- "error", "Unknown Composio Calendar error"
- )
- )
- updated = composio_result.get("data", {})
- if isinstance(updated, dict):
- updated = updated.get("data", updated)
- if isinstance(updated, dict):
- updated = updated.get("response_data", updated)
- else:
- service = await asyncio.get_event_loop().run_in_executor(
- None, lambda: build("calendar", "v3", credentials=creds)
- )
- updated = await asyncio.get_event_loop().run_in_executor(
- None,
- lambda: (
- service.events()
- .patch(
- calendarId="primary",
- eventId=final_event_id,
- body=update_body,
- )
- .execute()
- ),
- )
- except Exception as api_err:
- from googleapiclient.errors import HttpError
-
- if isinstance(api_err, HttpError) and api_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
- )
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
- )
- )
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
- )
+ )
+ connector = result.scalars().first()
+ if not connector:
return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "error",
+ "message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
- raise
- logger.info(f"Calendar event updated: event_id={final_event_id}")
+ actual_connector_id = connector.id
- kb_message_suffix = ""
- if document_id is not None:
- try:
- from app.services.google_calendar import GoogleCalendarKBSyncService
+ logger.info(
+ f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
+ )
- kb_service = GoogleCalendarKBSyncService(db_session)
- kb_result = await kb_service.sync_after_update(
- document_id=document_id,
- event_id=final_event_id,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
+ is_composio_calendar = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+ )
+ if is_composio_calendar:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this connector.",
+ }
+ else:
+ config_data = dict(connector.config)
+
+ from app.config import config as app_config
+ from app.utils.oauth_security import TokenEncryption
+
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and app_config.SECRET_KEY:
+ token_encryption = TokenEncryption(app_config.SECRET_KEY)
+ for key in ("token", "refresh_token", "client_secret"):
+ if config_data.get(key):
+ config_data[key] = token_encryption.decrypt_token(
+ config_data[key]
+ )
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
+
+ creds = Credentials(
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes", []),
+ expiry=datetime.fromisoformat(exp) if exp else None,
)
- if kb_result["status"] == "success":
- kb_message_suffix = (
- " Your knowledge base has also been updated."
- )
- else:
- kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after update failed: {kb_err}")
- kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
- return {
- "status": "success",
- "event_id": final_event_id,
- "html_link": updated.get("htmlLink"),
- "message": f"Successfully updated the calendar event.{kb_message_suffix}",
- }
+ update_body: dict[str, Any] = {}
+ if final_new_summary is not None:
+ update_body["summary"] = final_new_summary
+ if final_new_start_datetime is not None:
+ update_body["start"] = _build_time_body(
+ final_new_start_datetime, context
+ )
+ if final_new_end_datetime is not None:
+ update_body["end"] = _build_time_body(
+ final_new_end_datetime, context
+ )
+ if final_new_description is not None:
+ update_body["description"] = final_new_description
+ if final_new_location is not None:
+ update_body["location"] = final_new_location
+ if final_new_attendees is not None:
+ update_body["attendees"] = [
+ {"email": e.strip()} for e in final_new_attendees if e.strip()
+ ]
+
+ if not update_body:
+ return {
+ "status": "error",
+ "message": "No changes specified. Please provide at least one field to update.",
+ }
+
+ try:
+ if is_composio_calendar:
+ from app.services.composio_service import ComposioService
+
+ composio_params: dict[str, Any] = {
+ "calendar_id": "primary",
+ "event_id": final_event_id,
+ }
+ if final_new_summary is not None:
+ composio_params["summary"] = final_new_summary
+ if final_new_start_datetime is not None:
+ composio_params["start_time"] = final_new_start_datetime
+ if final_new_end_datetime is not None:
+ composio_params["end_time"] = final_new_end_datetime
+ if final_new_description is not None:
+ composio_params["description"] = final_new_description
+ if final_new_location is not None:
+ composio_params["location"] = final_new_location
+ if final_new_attendees is not None:
+ composio_params["attendees"] = [
+ e.strip() for e in final_new_attendees if e.strip()
+ ]
+ if not _is_date_only(
+ final_new_start_datetime or final_new_end_datetime or ""
+ ):
+ composio_params["timezone"] = context.get("timezone", "UTC")
+
+ composio_result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLECALENDAR_PATCH_EVENT",
+ params=composio_params,
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not composio_result.get("success"):
+ raise RuntimeError(
+ composio_result.get(
+ "error", "Unknown Composio Calendar error"
+ )
+ )
+ updated = composio_result.get("data", {})
+ if isinstance(updated, dict):
+ updated = updated.get("data", updated)
+ if isinstance(updated, dict):
+ updated = updated.get("response_data", updated)
+ else:
+ service = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: build("calendar", "v3", credentials=creds)
+ )
+ updated = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: (
+ service.events()
+ .patch(
+ calendarId="primary",
+ eventId=final_event_id,
+ body=update_body,
+ )
+ .execute()
+ ),
+ )
+ except Exception as api_err:
+ from googleapiclient.errors import HttpError
+
+ if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(f"Calendar event updated: event_id={final_event_id}")
+
+ kb_message_suffix = ""
+ if document_id is not None:
+ try:
+ from app.services.google_calendar import (
+ GoogleCalendarKBSyncService,
+ )
+
+ kb_service = GoogleCalendarKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_update(
+ document_id=document_id,
+ event_id=final_event_id,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after update failed: {kb_err}")
+ kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "event_id": final_event_id,
+ "html_link": updated.get("htmlLink"),
+ "message": f"Successfully updated the calendar event.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
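
Both Calendar write tools rebuild Google credentials from connector config the same way: decrypt the token fields when _token_encrypted is set, strip a trailing "Z" from the stored expiry so datetime.fromisoformat accepts it, and construct a Credentials object. A sketch of that step in isolation, assuming the config keys shown in the diff; the helper name credentials_from_config is illustrative and not part of the patch:

    # Illustrative helper; credentials_from_config is not defined in this patch.
    from datetime import datetime

    from google.oauth2.credentials import Credentials

    from app.config import config as app_config
    from app.utils.oauth_security import TokenEncryption


    def credentials_from_config(raw_config: dict) -> Credentials:
        """Build Google OAuth credentials from a connector's stored config."""
        config_data = dict(raw_config)
        if config_data.get("_token_encrypted", False) and app_config.SECRET_KEY:
            enc = TokenEncryption(app_config.SECRET_KEY)
            for key in ("token", "refresh_token", "client_secret"):
                if config_data.get(key):
                    config_data[key] = enc.decrypt_token(config_data[key])
        # The stored expiry may carry a trailing "Z"; older fromisoformat rejects it.
        exp = (config_data.get("expiry") or "").replace("Z", "")
        return Credentials(
            token=config_data.get("token"),
            refresh_token=config_data.get("refresh_token"),
            token_uri=config_data.get("token_uri"),
            client_id=config_data.get("client_id"),
            client_secret=config_data.get("client_secret"),
            scopes=config_data.get("scopes", []),
            expiry=datetime.fromisoformat(exp) if exp else None,
        )
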
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py
index 2becec100..66199ca67 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py
@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
+from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__)
@@ -23,6 +24,25 @@ def create_create_google_drive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_google_drive_file tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Google Drive connector
+ user_id: User ID for fetching user-specific context
+
+ Returns:
+ Configured create_google_drive_file tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_google_drive_file(
name: str,
@@ -65,7 +85,7 @@ def create_create_google_drive_file_tool(
f"create_google_drive_file called: name='{name}', type='{file_type}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Drive tool not properly configured. Please contact support.",
@@ -78,225 +98,232 @@ def create_create_google_drive_file_tool(
}
try:
- metadata_service = GoogleDriveToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
-
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {"status": "error", "message": context["error"]}
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- logger.warning("All Google Drive accounts have expired authentication")
- return {
- "status": "auth_error",
- "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "google_drive",
- }
-
- logger.info(
- f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
- )
- result = request_approval(
- action_type="google_drive_file_creation",
- tool_name="create_google_drive_file",
- params={
- "name": name,
- "file_type": file_type,
- "content": content,
- "connector_id": None,
- "parent_folder_id": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
- }
-
- final_name = result.params.get("name", name)
- final_file_type = result.params.get("file_type", file_type)
- final_content = result.params.get("content", content)
- final_connector_id = result.params.get("connector_id")
- final_parent_folder_id = result.params.get("parent_folder_id")
-
- if not final_name or not final_name.strip():
- return {"status": "error", "message": "File name cannot be empty."}
-
- mime_type = _MIME_MAP.get(final_file_type)
- if not mime_type:
- return {
- "status": "error",
- "message": f"Unsupported file type '{final_file_type}'.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _drive_types = [
- SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
- ]
-
- if final_connector_id is not None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_drive_types),
- )
+ async with async_session_maker() as db_session:
+ metadata_service = GoogleDriveToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
)
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Google Drive connector is invalid or has been disconnected.",
- }
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_drive_types),
+
+ if "error" in context:
+ logger.error(
+ f"Failed to fetch creation context: {context['error']}"
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
- }
- actual_connector_id = connector.id
+ return {"status": "error", "message": context["error"]}
- logger.info(
- f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
- )
-
- is_composio_drive = (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
- )
- if is_composio_drive:
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
- return {
- "status": "error",
- "message": "Composio connected account ID not found for this Drive connector.",
- }
- client = GoogleDriveClient(
- session=db_session,
- connector_id=actual_connector_id,
- )
- try:
- if is_composio_drive:
- from app.services.composio_service import ComposioService
-
- params: dict[str, Any] = {
- "name": final_name,
- "mimeType": mime_type,
- "fields": "id,name,webViewLink,mimeType",
- }
- if final_parent_folder_id:
- params["parents"] = [final_parent_folder_id]
- if final_content:
- params["description"] = final_content[:4096]
-
- result = await ComposioService().execute_tool(
- connected_account_id=cca_id,
- tool_name="GOOGLEDRIVE_CREATE_FILE",
- params=params,
- entity_id=f"surfsense_{user_id}",
- )
- if not result.get("success"):
- raise RuntimeError(
- result.get("error", "Unknown Composio Drive error")
- )
- created = result.get("data", {})
- if isinstance(created, dict):
- created = created.get("data", created)
- if isinstance(created, dict):
- created = created.get("response_data", created)
- if not isinstance(created, dict):
- created = {}
- else:
- created = await client.create_file(
- name=final_name,
- mime_type=mime_type,
- parent_folder_id=final_parent_folder_id,
- content=final_content,
- )
- except HttpError as http_err:
- if http_err.resp.status == 403:
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning(
- f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
+ "All Google Drive accounts have expired authentication"
)
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- _res = await db_session.execute(
- select(SearchSourceConnector).where(
- SearchSourceConnector.id == actual_connector_id
- )
- )
- _conn = _res.scalar_one_or_none()
- if _conn and not _conn.config.get("auth_expired"):
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- actual_connector_id,
- exc_info=True,
- )
return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "auth_error",
+ "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "google_drive",
}
- raise
- logger.info(
- f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
- )
-
- kb_message_suffix = ""
- try:
- from app.services.google_drive import GoogleDriveKBSyncService
-
- kb_service = GoogleDriveKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- file_id=created.get("id"),
- file_name=created.get("name", final_name),
- mime_type=mime_type,
- web_view_link=created.get("webViewLink"),
- content=final_content,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
+ logger.info(
+ f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
+ )
+ result = request_approval(
+ action_type="google_drive_file_creation",
+ tool_name="create_google_drive_file",
+ params={
+ "name": name,
+ "file_type": file_type,
+ "content": content,
+ "connector_id": None,
+ "parent_folder_id": None,
+ },
+ context=context,
)
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
- else:
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
- return {
- "status": "success",
- "file_id": created.get("id"),
- "name": created.get("name"),
- "web_view_link": created.get("webViewLink"),
- "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
+ }
+
+ final_name = result.params.get("name", name)
+ final_file_type = result.params.get("file_type", file_type)
+ final_content = result.params.get("content", content)
+ final_connector_id = result.params.get("connector_id")
+ final_parent_folder_id = result.params.get("parent_folder_id")
+
+ if not final_name or not final_name.strip():
+ return {"status": "error", "message": "File name cannot be empty."}
+
+ mime_type = _MIME_MAP.get(final_file_type)
+ if not mime_type:
+ return {
+ "status": "error",
+ "message": f"Unsupported file type '{final_file_type}'.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _drive_types = [
+ SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
+ ]
+
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_drive_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Google Drive connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = connector.id
+ else:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_drive_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
+ }
+ actual_connector_id = connector.id
+
+ logger.info(
+ f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
+ )
+
+ is_composio_drive = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
+ )
+ if is_composio_drive:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Drive connector.",
+ }
+ client = GoogleDriveClient(
+ session=db_session,
+ connector_id=actual_connector_id,
+ )
+ try:
+ if is_composio_drive:
+ from app.services.composio_service import ComposioService
+
+ params: dict[str, Any] = {
+ "name": final_name,
+ "mimeType": mime_type,
+ "fields": "id,name,webViewLink,mimeType",
+ }
+ if final_parent_folder_id:
+ params["parents"] = [final_parent_folder_id]
+ if final_content:
+ params["description"] = final_content[:4096]
+
+ result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLEDRIVE_CREATE_FILE",
+ params=params,
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not result.get("success"):
+ raise RuntimeError(
+ result.get("error", "Unknown Composio Drive error")
+ )
+ created = result.get("data", {})
+ if isinstance(created, dict):
+ created = created.get("data", created)
+ if isinstance(created, dict):
+ created = created.get("response_data", created)
+ if not isinstance(created, dict):
+ created = {}
+ else:
+ created = await client.create_file(
+ name=final_name,
+ mime_type=mime_type,
+ parent_folder_id=final_parent_folder_id,
+ content=final_content,
+ )
+ except HttpError as http_err:
+ if http_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ _res = await db_session.execute(
+ select(SearchSourceConnector).where(
+ SearchSourceConnector.id == actual_connector_id
+ )
+ )
+ _conn = _res.scalar_one_or_none()
+ if _conn and not _conn.config.get("auth_expired"):
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ actual_connector_id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(
+ f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
+ )
+
+ kb_message_suffix = ""
+ try:
+ from app.services.google_drive import GoogleDriveKBSyncService
+
+ kb_service = GoogleDriveKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ file_id=created.get("id"),
+ file_name=created.get("name", final_name),
+ mime_type=mime_type,
+ web_view_link=created.get("webViewLink"),
+ content=final_content,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "file_id": created.get("id"),
+ "name": created.get("name"),
+ "web_view_link": created.get("webViewLink"),
+ "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
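
Several tools in this patch unwrap the Composio execute_tool response the same way: fail on success=False, then peel the payload out of nested data and response_data keys, falling back to an empty dict. A sketch of that unwrapping as a helper, assuming the response shape used in the diff; unwrap_composio_payload is an illustrative name, not part of the patch:

    # Illustrative helper mirroring the unwrapping done inline in the diff above.
    from typing import Any


    def unwrap_composio_payload(result: dict[str, Any]) -> dict[str, Any]:
        """Return the innermost dict payload of a Composio execute_tool result."""
        if not result.get("success"):
            raise RuntimeError(result.get("error", "Unknown Composio error"))
        payload = result.get("data", {})
        if isinstance(payload, dict):
            payload = payload.get("data", payload)
        if isinstance(payload, dict):
            payload = payload.get("response_data", payload)
        return payload if isinstance(payload, dict) else {}
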
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py
index 3c404527e..b3c9240d8 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py
@@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
+from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__)
@@ -17,6 +18,25 @@ def create_delete_google_drive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the delete_google_drive_file tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`, so the compiled-agent cache can safely
+    share the closure across HTTP requests. Capturing a per-request
+    session here would surface stale or closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Google Drive connector
+ user_id: User ID for fetching user-specific context
+
+ Returns:
+ Configured delete_google_drive_file tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def delete_google_drive_file(
file_name: str,
@@ -55,211 +75,214 @@ def create_delete_google_drive_file_tool(
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Drive tool not properly configured. Please contact support.",
}
try:
- metadata_service = GoogleDriveToolMetadataService(db_session)
- context = await metadata_service.get_trash_context(
- search_space_id, user_id, file_name
- )
-
- if "error" in context:
- error_msg = context["error"]
- if "not found" in error_msg.lower():
- logger.warning(f"File not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- logger.error(f"Failed to fetch trash context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Google Drive account %s has expired authentication",
- account.get("id"),
+ async with async_session_maker() as db_session:
+ metadata_service = GoogleDriveToolMetadataService(db_session)
+ context = await metadata_service.get_trash_context(
+ search_space_id, user_id, file_name
)
- return {
- "status": "auth_error",
- "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "google_drive",
- }
- file = context["file"]
- file_id = file["file_id"]
- document_id = file.get("document_id")
- connector_id_from_context = context["account"]["id"]
+ if "error" in context:
+ error_msg = context["error"]
+ if "not found" in error_msg.lower():
+ logger.warning(f"File not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ logger.error(f"Failed to fetch trash context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- if not file_id:
- return {
- "status": "error",
- "message": "File ID is missing from the indexed document. Please re-index the file and try again.",
- }
+ account = context.get("account", {})
+ if account.get("auth_expired"):
+ logger.warning(
+ "Google Drive account %s has expired authentication",
+ account.get("id"),
+ )
+ return {
+ "status": "auth_error",
+ "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "google_drive",
+ }
- logger.info(
- f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
- )
- result = request_approval(
- action_type="google_drive_file_trash",
- tool_name="delete_google_drive_file",
- params={
- "file_id": file_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
+ file = context["file"]
+ file_id = file["file_id"]
+ document_id = file.get("document_id")
+ connector_id_from_context = context["account"]["id"]
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
- }
-
- final_file_id = result.params.get("file_id", file_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this file.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- _drive_types = [
- SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
- SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
- ]
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type.in_(_drive_types),
- )
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Google Drive connector is invalid or has been disconnected.",
- }
-
- logger.info(
- f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
- )
-
- is_composio_drive = (
- connector.connector_type
- == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
- )
- if is_composio_drive:
- cca_id = connector.config.get("composio_connected_account_id")
- if not cca_id:
+ if not file_id:
return {
"status": "error",
- "message": "Composio connected account ID not found for this Drive connector.",
+ "message": "File ID is missing from the indexed document. Please re-index the file and try again.",
}
- client = GoogleDriveClient(
- session=db_session,
- connector_id=connector.id,
- )
- try:
- if is_composio_drive:
- from app.services.composio_service import ComposioService
-
- result = await ComposioService().execute_tool(
- connected_account_id=cca_id,
- tool_name="GOOGLEDRIVE_TRASH_FILE",
- params={"file_id": final_file_id},
- entity_id=f"surfsense_{user_id}",
- )
- if not result.get("success"):
- raise RuntimeError(
- result.get("error", "Unknown Composio Drive error")
- )
- else:
- await client.trash_file(file_id=final_file_id)
- except HttpError as http_err:
- if http_err.resp.status == 403:
- logger.warning(
- f"Insufficient permissions for connector {connector.id}: {http_err}"
- )
- try:
- from sqlalchemy.orm.attributes import flag_modified
-
- if not connector.config.get("auth_expired"):
- connector.config = {
- **connector.config,
- "auth_expired": True,
- }
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- logger.warning(
- "Failed to persist auth_expired for connector %s",
- connector.id,
- exc_info=True,
- )
- return {
- "status": "insufficient_permissions",
- "connector_id": connector.id,
- "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
- }
- raise
-
- logger.info(
- f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
- )
-
- trash_result: dict[str, Any] = {
- "status": "success",
- "file_id": final_file_id,
- "message": f"Successfully moved '{file['name']}' to trash.",
- }
-
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- from app.db import Document
-
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- trash_result["warning"] = (
- f"File moved to trash, but failed to remove from knowledge base: {e!s}"
- )
-
- trash_result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- trash_result["message"] = (
- f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ logger.info(
+ f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
+ )
+ result = request_approval(
+ action_type="google_drive_file_trash",
+ tool_name="delete_google_drive_file",
+ params={
+ "file_id": file_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- return trash_result
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
+ }
+
+ final_file_id = result.params.get("file_id", file_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
+
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this file.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ _drive_types = [
+ SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
+ SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
+ ]
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type.in_(_drive_types),
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Google Drive connector is invalid or has been disconnected.",
+ }
+
+ logger.info(
+ f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
+ )
+
+ is_composio_drive = (
+ connector.connector_type
+ == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
+ )
+ if is_composio_drive:
+ cca_id = connector.config.get("composio_connected_account_id")
+ if not cca_id:
+ return {
+ "status": "error",
+ "message": "Composio connected account ID not found for this Drive connector.",
+ }
+
+ client = GoogleDriveClient(
+ session=db_session,
+ connector_id=connector.id,
+ )
+ try:
+ if is_composio_drive:
+ from app.services.composio_service import ComposioService
+
+ result = await ComposioService().execute_tool(
+ connected_account_id=cca_id,
+ tool_name="GOOGLEDRIVE_TRASH_FILE",
+ params={"file_id": final_file_id},
+ entity_id=f"surfsense_{user_id}",
+ )
+ if not result.get("success"):
+ raise RuntimeError(
+ result.get("error", "Unknown Composio Drive error")
+ )
+ else:
+ await client.trash_file(file_id=final_file_id)
+ except HttpError as http_err:
+ if http_err.resp.status == 403:
+ logger.warning(
+ f"Insufficient permissions for connector {connector.id}: {http_err}"
+ )
+ try:
+ from sqlalchemy.orm.attributes import flag_modified
+
+ if not connector.config.get("auth_expired"):
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ logger.warning(
+ "Failed to persist auth_expired for connector %s",
+ connector.id,
+ exc_info=True,
+ )
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": connector.id,
+ "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ logger.info(
+ f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
+ )
+
+ trash_result: dict[str, Any] = {
+ "status": "success",
+ "file_id": final_file_id,
+ "message": f"Successfully moved '{file['name']}' to trash.",
+ }
+
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ trash_result["warning"] = (
+ f"File moved to trash, but failed to remove from knowledge base: {e!s}"
+ )
+
+ trash_result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ trash_result["message"] = (
+ f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt
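
Note on the pattern above: every factory in this series swaps a captured request session for a session opened inside the tool body. Below is a minimal sketch of that shape, assuming only the async_session_maker imported from app.db in this patch; the factory and tool names are illustrative, not part of the change.

from app.db import async_session_maker


def create_example_tool(db_session=None, search_space_id=None, user_id=None):
    # The factory runs once when the compiled agent is built and cached,
    # so anything captured here outlives the request that built it.
    del db_session  # reserved for registry compatibility

    async def example_tool(query: str) -> dict:
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Tool not configured."}
        # Each call opens and closes its own short-lived session, so a
        # cache hit from a later request never sees a stale/closed one.
        async with async_session_maker() as session:
            _ = session  # query or mutate via `session` here
        return {"status": "success", "query": query}

    return example_tool
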
diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py
index 8b40dde65..0b04f1642 100644
--- a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py
@@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
+from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,28 @@ def create_create_jira_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """Factory function to create the create_jira_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits. Per-call sessions also
+ avoid holding the request's outer transaction open while long-running
+ Jira API calls block.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Jira connector
+ user_id: User ID for fetching user-specific context
+ connector_id: Optional specific connector ID (if known)
+
+ Returns:
+ Configured create_jira_issue tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_jira_issue(
project_key: str,
@@ -49,158 +72,167 @@ def create_create_jira_issue_tool(
f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
- metadata_service = JiraToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
-
- if "error" in context:
- return {"status": "error", "message": context["error"]}
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- return {
- "status": "auth_error",
- "message": "All connected Jira accounts need re-authentication.",
- "connector_type": "jira",
- }
-
- result = request_approval(
- action_type="jira_issue_creation",
- tool_name="create_jira_issue",
- params={
- "project_key": project_key,
- "summary": summary,
- "issue_type": issue_type,
- "description": description,
- "priority": priority,
- "connector_id": connector_id,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_project_key = result.params.get("project_key", project_key)
- final_summary = result.params.get("summary", summary)
- final_issue_type = result.params.get("issue_type", issue_type)
- final_description = result.params.get("description", description)
- final_priority = result.params.get("priority", priority)
- final_connector_id = result.params.get("connector_id", connector_id)
-
- if not final_summary or not final_summary.strip():
- return {"status": "error", "message": "Issue summary cannot be empty."}
- if not final_project_key:
- return {"status": "error", "message": "A project must be selected."}
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- actual_connector_id = final_connector_id
- if actual_connector_id is None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.JIRA_CONNECTOR,
- )
+ async with async_session_maker() as db_session:
+ metadata_service = JiraToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
)
- connector = result.scalars().first()
- if not connector:
- return {"status": "error", "message": "No Jira connector found."}
- actual_connector_id = connector.id
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == actual_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.JIRA_CONNECTOR,
- )
+
+ if "error" in context:
+ return {"status": "error", "message": context["error"]}
+
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ return {
+ "status": "auth_error",
+ "message": "All connected Jira accounts need re-authentication.",
+ "connector_type": "jira",
+ }
+
+ result = request_approval(
+ action_type="jira_issue_creation",
+ tool_name="create_jira_issue",
+ params={
+ "project_key": project_key,
+ "summary": summary,
+ "issue_type": issue_type,
+ "description": description,
+ "priority": priority,
+ "connector_id": connector_id,
+ },
+ context=context,
)
- connector = result.scalars().first()
- if not connector:
+
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_project_key = result.params.get("project_key", project_key)
+ final_summary = result.params.get("summary", summary)
+ final_issue_type = result.params.get("issue_type", issue_type)
+ final_description = result.params.get("description", description)
+ final_priority = result.params.get("priority", priority)
+ final_connector_id = result.params.get("connector_id", connector_id)
+
+ if not final_summary or not final_summary.strip():
return {
"status": "error",
- "message": "Selected Jira connector is invalid.",
+ "message": "Issue summary cannot be empty.",
}
+ if not final_project_key:
+ return {"status": "error", "message": "A project must be selected."}
- try:
- jira_history = JiraHistoryConnector(
- session=db_session, connector_id=actual_connector_id
- )
- jira_client = await jira_history._get_jira_client()
- api_result = await asyncio.to_thread(
- jira_client.create_issue,
- project_key=final_project_key,
- summary=final_summary,
- issue_type=final_issue_type,
- description=final_description,
- priority=final_priority,
- )
- except Exception as api_err:
- if "status code 403" in str(api_err).lower():
- try:
- _conn = connector
- _conn.config = {**_conn.config, "auth_expired": True}
- flag_modified(_conn, "config")
- await db_session.commit()
- except Exception:
- pass
- return {
- "status": "insufficient_permissions",
- "connector_id": actual_connector_id,
- "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
- }
- raise
+ from sqlalchemy.future import select
- issue_key = api_result.get("key", "")
- issue_url = (
- f"{jira_history._base_url}/browse/{issue_key}"
- if jira_history._base_url and issue_key
- else ""
- )
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
- kb_message_suffix = ""
- try:
- from app.services.jira import JiraKBSyncService
-
- kb_service = JiraKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- issue_id=issue_key,
- issue_identifier=issue_key,
- issue_title=final_summary,
- description=final_description,
- state="To Do",
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
+ actual_connector_id = final_connector_id
+ if actual_connector_id is None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.JIRA_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Jira connector found.",
+ }
+ actual_connector_id = connector.id
else:
- kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == actual_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.JIRA_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Jira connector is invalid.",
+ }
- return {
- "status": "success",
- "issue_key": issue_key,
- "issue_url": issue_url,
- "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
- }
+ try:
+ jira_history = JiraHistoryConnector(
+ session=db_session, connector_id=actual_connector_id
+ )
+ jira_client = await jira_history._get_jira_client()
+ api_result = await asyncio.to_thread(
+ jira_client.create_issue,
+ project_key=final_project_key,
+ summary=final_summary,
+ issue_type=final_issue_type,
+ description=final_description,
+ priority=final_priority,
+ )
+ except Exception as api_err:
+ if "status code 403" in str(api_err).lower():
+ try:
+ _conn = connector
+ _conn.config = {**_conn.config, "auth_expired": True}
+ flag_modified(_conn, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": actual_connector_id,
+ "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ issue_key = api_result.get("key", "")
+ issue_url = (
+ f"{jira_history._base_url}/browse/{issue_key}"
+ if jira_history._base_url and issue_key
+ else ""
+ )
+
+ kb_message_suffix = ""
+ try:
+ from app.services.jira import JiraKBSyncService
+
+ kb_service = JiraKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ issue_id=issue_key,
+ issue_identifier=issue_key,
+ issue_title=final_summary,
+ description=final_description,
+ state="To Do",
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "issue_key": issue_key,
+ "issue_url": issue_url,
+ "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
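
The Jira and Linear tools share the same human-in-the-loop flow: propose params, call request_approval, bail out on rejection, then merge the possibly edited params back with the proposed values as fallbacks. A short sketch, assuming request_approval from app.agents.new_chat.tools.hitl behaves as used above; the surrounding function and its params are illustrative.

from app.agents.new_chat.tools.hitl import request_approval


async def example_create(summary: str, description: str | None = None) -> dict:
    result = request_approval(
        action_type="jira_issue_creation",
        tool_name="create_jira_issue",
        params={"summary": summary, "description": description},
        context={},  # metadata shown to the user in the approval UI
    )
    if result.rejected:
        # Terminal outcome: the tools return "rejected" and never retry.
        return {"status": "rejected"}

    # The user may edit fields during approval, so each final value falls
    # back to the originally proposed one when it is not overridden.
    final_summary = result.params.get("summary", summary)
    final_description = result.params.get("description", description)
    return {
        "status": "approved",
        "summary": final_summary,
        "description": final_description,
    }
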
diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py
index 6466c80ea..c41aedad9 100644
--- a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py
@@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
+from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,26 @@ def create_delete_jira_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """Factory function to create the delete_jira_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Jira connector
+ user_id: User ID for fetching user-specific context
+ connector_id: Optional specific connector ID (if known)
+
+ Returns:
+ Configured delete_jira_issue tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def delete_jira_issue(
issue_title_or_key: str,
@@ -44,130 +65,136 @@ def create_delete_jira_issue_tool(
f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
- metadata_service = JiraToolMetadataService(db_session)
- context = await metadata_service.get_deletion_context(
- search_space_id, user_id, issue_title_or_key
- )
-
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
- return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "jira",
- }
- if "not found" in error_msg.lower():
- return {"status": "not_found", "message": error_msg}
- return {"status": "error", "message": error_msg}
-
- issue_data = context["issue"]
- issue_key = issue_data["issue_id"]
- document_id = issue_data["document_id"]
- connector_id_from_context = context.get("account", {}).get("id")
-
- result = request_approval(
- action_type="jira_issue_deletion",
- tool_name="delete_jira_issue",
- params={
- "issue_key": issue_key,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_issue_key = result.params.get("issue_key", issue_key)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this issue.",
- }
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.JIRA_CONNECTOR,
+ async with async_session_maker() as db_session:
+ metadata_service = JiraToolMetadataService(db_session)
+ context = await metadata_service.get_deletion_context(
+ search_space_id, user_id, issue_title_or_key
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Jira connector is invalid.",
- }
- try:
- jira_history = JiraHistoryConnector(
- session=db_session, connector_id=final_connector_id
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "jira",
+ }
+ if "not found" in error_msg.lower():
+ return {"status": "not_found", "message": error_msg}
+ return {"status": "error", "message": error_msg}
+
+ issue_data = context["issue"]
+ issue_key = issue_data["issue_id"]
+ document_id = issue_data["document_id"]
+ connector_id_from_context = context.get("account", {}).get("id")
+
+ result = request_approval(
+ action_type="jira_issue_deletion",
+ tool_name="delete_jira_issue",
+ params={
+ "issue_key": issue_key,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- jira_client = await jira_history._get_jira_client()
- await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
- except Exception as api_err:
- if "status code 403" in str(api_err).lower():
- try:
- connector.config = {**connector.config, "auth_expired": True}
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- pass
+
+ if result.rejected:
return {
- "status": "insufficient_permissions",
- "connector_id": final_connector_id,
- "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
}
- raise
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- from app.db import Document
+ final_issue_key = result.params.get("issue_key", issue_key)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this issue.",
+ }
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.JIRA_CONNECTOR,
)
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Jira connector is invalid.",
+ }
- message = f"Jira issue {final_issue_key} deleted successfully."
- if deleted_from_kb:
- message += " Also removed from the knowledge base."
+ try:
+ jira_history = JiraHistoryConnector(
+ session=db_session, connector_id=final_connector_id
+ )
+ jira_client = await jira_history._get_jira_client()
+ await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
+ except Exception as api_err:
+ if "status code 403" in str(api_err).lower():
+ try:
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": final_connector_id,
+ "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
- return {
- "status": "success",
- "issue_key": final_issue_key,
- "deleted_from_kb": deleted_from_kb,
- "message": message,
- }
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ from app.db import Document
+
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+
+ message = f"Jira issue {final_issue_key} deleted successfully."
+ if deleted_from_kb:
+ message += " Also removed from the knowledge base."
+
+ return {
+ "status": "success",
+ "issue_key": final_issue_key,
+ "deleted_from_kb": deleted_from_kb,
+ "message": message,
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
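
All three Jira tools react to a 403 the same way: flag the connector config as auth-expired, commit, and return an insufficient_permissions result. A sketch of just the persistence step, assuming an open AsyncSession and a loaded SearchSourceConnector row; the helper name is illustrative.

from sqlalchemy.orm.attributes import flag_modified


async def mark_auth_expired(db_session, connector) -> None:
    # Reassign a fresh dict (rather than mutating in place) and call
    # flag_modified() so SQLAlchemy records the JSON column change.
    connector.config = {**connector.config, "auth_expired": True}
    flag_modified(connector, "config")
    await db_session.commit()
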
diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py
index f6e586a2e..0fd7b28b3 100644
--- a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py
@@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
+from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
@@ -19,6 +20,26 @@ def create_update_jira_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
+ """Factory function to create the update_jira_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ search_space_id: Search space ID to find the Jira connector
+ user_id: User ID for fetching user-specific context
+ connector_id: Optional specific connector ID (if known)
+
+ Returns:
+ Configured update_jira_issue tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def update_jira_issue(
issue_title_or_key: str,
@@ -48,169 +69,177 @@ def create_update_jira_issue_tool(
f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
- metadata_service = JiraToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, issue_title_or_key
- )
-
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
- return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "jira",
- }
- if "not found" in error_msg.lower():
- return {"status": "not_found", "message": error_msg}
- return {"status": "error", "message": error_msg}
-
- issue_data = context["issue"]
- issue_key = issue_data["issue_id"]
- document_id = issue_data.get("document_id")
- connector_id_from_context = context.get("account", {}).get("id")
-
- result = request_approval(
- action_type="jira_issue_update",
- tool_name="update_jira_issue",
- params={
- "issue_key": issue_key,
- "document_id": document_id,
- "new_summary": new_summary,
- "new_description": new_description,
- "new_priority": new_priority,
- "connector_id": connector_id_from_context,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_issue_key = result.params.get("issue_key", issue_key)
- final_summary = result.params.get("new_summary", new_summary)
- final_description = result.params.get("new_description", new_description)
- final_priority = result.params.get("new_priority", new_priority)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_document_id = result.params.get("document_id", document_id)
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if not final_connector_id:
- return {
- "status": "error",
- "message": "No connector found for this issue.",
- }
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.JIRA_CONNECTOR,
+ async with async_session_maker() as db_session:
+ metadata_service = JiraToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, issue_title_or_key
)
- )
- connector = result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "Selected Jira connector is invalid.",
- }
- fields: dict[str, Any] = {}
- if final_summary:
- fields["summary"] = final_summary
- if final_description is not None:
- fields["description"] = {
- "type": "doc",
- "version": 1,
- "content": [
- {
- "type": "paragraph",
- "content": [{"type": "text", "text": final_description}],
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "jira",
}
- ],
- }
- if final_priority:
- fields["priority"] = {"name": final_priority}
+ if "not found" in error_msg.lower():
+ return {"status": "not_found", "message": error_msg}
+ return {"status": "error", "message": error_msg}
- if not fields:
- return {"status": "error", "message": "No changes specified."}
+ issue_data = context["issue"]
+ issue_key = issue_data["issue_id"]
+ document_id = issue_data.get("document_id")
+ connector_id_from_context = context.get("account", {}).get("id")
- try:
- jira_history = JiraHistoryConnector(
- session=db_session, connector_id=final_connector_id
+ result = request_approval(
+ action_type="jira_issue_update",
+ tool_name="update_jira_issue",
+ params={
+ "issue_key": issue_key,
+ "document_id": document_id,
+ "new_summary": new_summary,
+ "new_description": new_description,
+ "new_priority": new_priority,
+ "connector_id": connector_id_from_context,
+ },
+ context=context,
)
- jira_client = await jira_history._get_jira_client()
- await asyncio.to_thread(
- jira_client.update_issue, final_issue_key, fields
- )
- except Exception as api_err:
- if "status code 403" in str(api_err).lower():
- try:
- connector.config = {**connector.config, "auth_expired": True}
- flag_modified(connector, "config")
- await db_session.commit()
- except Exception:
- pass
+
+ if result.rejected:
return {
- "status": "insufficient_permissions",
- "connector_id": final_connector_id,
- "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
}
- raise
- issue_url = (
- f"{jira_history._base_url}/browse/{final_issue_key}"
- if jira_history._base_url and final_issue_key
- else ""
- )
+ final_issue_key = result.params.get("issue_key", issue_key)
+ final_summary = result.params.get("new_summary", new_summary)
+ final_description = result.params.get(
+ "new_description", new_description
+ )
+ final_priority = result.params.get("new_priority", new_priority)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_document_id = result.params.get("document_id", document_id)
- kb_message_suffix = ""
- if final_document_id:
- try:
- from app.services.jira import JiraKBSyncService
+ from sqlalchemy.future import select
- kb_service = JiraKBSyncService(db_session)
- kb_result = await kb_service.sync_after_update(
- document_id=final_document_id,
- issue_id=final_issue_key,
- user_id=user_id,
- search_space_id=search_space_id,
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if not final_connector_id:
+ return {
+ "status": "error",
+ "message": "No connector found for this issue.",
+ }
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.JIRA_CONNECTOR,
)
- if kb_result["status"] == "success":
- kb_message_suffix = (
- " Your knowledge base has also been updated."
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Jira connector is invalid.",
+ }
+
+ fields: dict[str, Any] = {}
+ if final_summary:
+ fields["summary"] = final_summary
+ if final_description is not None:
+ fields["description"] = {
+ "type": "doc",
+ "version": 1,
+ "content": [
+ {
+ "type": "paragraph",
+ "content": [
+ {"type": "text", "text": final_description}
+ ],
+ }
+ ],
+ }
+ if final_priority:
+ fields["priority"] = {"name": final_priority}
+
+ if not fields:
+ return {"status": "error", "message": "No changes specified."}
+
+ try:
+ jira_history = JiraHistoryConnector(
+ session=db_session, connector_id=final_connector_id
+ )
+ jira_client = await jira_history._get_jira_client()
+ await asyncio.to_thread(
+ jira_client.update_issue, final_issue_key, fields
+ )
+ except Exception as api_err:
+ if "status code 403" in str(api_err).lower():
+ try:
+ connector.config = {
+ **connector.config,
+ "auth_expired": True,
+ }
+ flag_modified(connector, "config")
+ await db_session.commit()
+ except Exception:
+ pass
+ return {
+ "status": "insufficient_permissions",
+ "connector_id": final_connector_id,
+ "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
+ }
+ raise
+
+ issue_url = (
+ f"{jira_history._base_url}/browse/{final_issue_key}"
+ if jira_history._base_url and final_issue_key
+ else ""
+ )
+
+ kb_message_suffix = ""
+ if final_document_id:
+ try:
+ from app.services.jira import JiraKBSyncService
+
+ kb_service = JiraKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_update(
+ document_id=final_document_id,
+ issue_id=final_issue_key,
+ user_id=user_id,
+ search_space_id=search_space_id,
)
- else:
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = (
+ " The knowledge base will be updated in the next sync."
+ )
+ except Exception as kb_err:
+ logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
- except Exception as kb_err:
- logger.warning(f"KB sync after update failed: {kb_err}")
- kb_message_suffix = (
- " The knowledge base will be updated in the next sync."
- )
- return {
- "status": "success",
- "issue_key": final_issue_key,
- "issue_url": issue_url,
- "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
- }
+ return {
+ "status": "success",
+ "issue_key": final_issue_key,
+ "issue_url": issue_url,
+ "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
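
The Jira client used by these tools is synchronous, so its calls are wrapped in asyncio.to_thread rather than blocking the event loop. A sketch of that wrapper, assuming jira_client is the blocking client returned by JiraHistoryConnector._get_jira_client; the helper itself is illustrative.

import asyncio


async def update_issue_off_loop(jira_client, issue_key: str, fields: dict) -> None:
    # update_issue does blocking HTTP I/O; to_thread runs it in the default
    # thread pool so the event loop keeps serving other coroutines.
    await asyncio.to_thread(jira_client.update_issue, issue_key, fields)
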
diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py
index ff254e133..f897bee7a 100644
--- a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
+from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__)
@@ -17,11 +18,17 @@ def create_create_linear_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
- """
- Factory function to create the create_linear_issue tool.
+ """Factory function to create the create_linear_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits.
Args:
- db_session: Database session for accessing the Linear connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@@ -29,6 +36,7 @@ def create_create_linear_issue_tool(
Returns:
Configured create_linear_issue tool
"""
+ del db_session # per-call session — see docstring
@tool
async def create_linear_issue(
@@ -65,7 +73,7 @@ def create_create_linear_issue_tool(
"""
logger.info(f"create_linear_issue called: title='{title}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Linear tool not properly configured - missing required parameters"
)
@@ -75,160 +83,170 @@ def create_create_linear_issue_tool(
}
try:
- metadata_service = LinearToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
-
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {"status": "error", "message": context["error"]}
-
- workspaces = context.get("workspaces", [])
- if workspaces and all(w.get("auth_expired") for w in workspaces):
- logger.warning("All Linear accounts have expired authentication")
- return {
- "status": "auth_error",
- "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "linear",
- }
-
- logger.info(f"Requesting approval for creating Linear issue: '{title}'")
- result = request_approval(
- action_type="linear_issue_creation",
- tool_name="create_linear_issue",
- params={
- "title": title,
- "description": description,
- "team_id": None,
- "state_id": None,
- "assignee_id": None,
- "priority": None,
- "label_ids": [],
- "connector_id": connector_id,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Linear issue creation rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_title = result.params.get("title", title)
- final_description = result.params.get("description", description)
- final_team_id = result.params.get("team_id")
- final_state_id = result.params.get("state_id")
- final_assignee_id = result.params.get("assignee_id")
- final_priority = result.params.get("priority")
- final_label_ids = result.params.get("label_ids") or []
- final_connector_id = result.params.get("connector_id", connector_id)
-
- if not final_title or not final_title.strip():
- logger.error("Title is empty or contains only whitespace")
- return {"status": "error", "message": "Issue title cannot be empty."}
- if not final_team_id:
- return {
- "status": "error",
- "message": "A team must be selected to create an issue.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- actual_connector_id = final_connector_id
- if actual_connector_id is None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.LINEAR_CONNECTOR,
- )
+ async with async_session_maker() as db_session:
+ metadata_service = LinearToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
)
- connector = result.scalars().first()
- if not connector:
+
+ if "error" in context:
+ logger.error(
+ f"Failed to fetch creation context: {context['error']}"
+ )
+ return {"status": "error", "message": context["error"]}
+
+ workspaces = context.get("workspaces", [])
+ if workspaces and all(w.get("auth_expired") for w in workspaces):
+ logger.warning("All Linear accounts have expired authentication")
+ return {
+ "status": "auth_error",
+ "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "linear",
+ }
+
+ logger.info(f"Requesting approval for creating Linear issue: '{title}'")
+ result = request_approval(
+ action_type="linear_issue_creation",
+ tool_name="create_linear_issue",
+ params={
+ "title": title,
+ "description": description,
+ "team_id": None,
+ "state_id": None,
+ "assignee_id": None,
+ "priority": None,
+ "label_ids": [],
+ "connector_id": connector_id,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ logger.info("Linear issue creation rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_title = result.params.get("title", title)
+ final_description = result.params.get("description", description)
+ final_team_id = result.params.get("team_id")
+ final_state_id = result.params.get("state_id")
+ final_assignee_id = result.params.get("assignee_id")
+ final_priority = result.params.get("priority")
+ final_label_ids = result.params.get("label_ids") or []
+ final_connector_id = result.params.get("connector_id", connector_id)
+
+ if not final_title or not final_title.strip():
+ logger.error("Title is empty or contains only whitespace")
return {
"status": "error",
- "message": "No Linear connector found. Please connect Linear in your workspace settings.",
+ "message": "Issue title cannot be empty.",
}
- actual_connector_id = connector.id
- logger.info(f"Found Linear connector: id={actual_connector_id}")
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == actual_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.LINEAR_CONNECTOR,
- )
- )
- connector = result.scalars().first()
- if not connector:
+ if not final_team_id:
return {
"status": "error",
- "message": "Selected Linear connector is invalid or has been disconnected.",
+ "message": "A team must be selected to create an issue.",
}
- logger.info(f"Validated Linear connector: id={actual_connector_id}")
- logger.info(
- f"Creating Linear issue with final params: title='{final_title}'"
- )
- linear_client = LinearConnector(
- session=db_session, connector_id=actual_connector_id
- )
- result = await linear_client.create_issue(
- team_id=final_team_id,
- title=final_title,
- description=final_description,
- state_id=final_state_id,
- assignee_id=final_assignee_id,
- priority=final_priority,
- label_ids=final_label_ids if final_label_ids else None,
- )
+ from sqlalchemy.future import select
- if result.get("status") == "error":
- logger.error(f"Failed to create Linear issue: {result.get('message')}")
- return {"status": "error", "message": result.get("message")}
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
- logger.info(
- f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
- )
-
- kb_message_suffix = ""
- try:
- from app.services.linear import LinearKBSyncService
-
- kb_service = LinearKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- issue_id=result.get("id"),
- issue_identifier=result.get("identifier", ""),
- issue_title=result.get("title", final_title),
- issue_url=result.get("url"),
- description=final_description,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
+ actual_connector_id = final_connector_id
+ if actual_connector_id is None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.LINEAR_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "No Linear connector found. Please connect Linear in your workspace settings.",
+ }
+ actual_connector_id = connector.id
+ logger.info(f"Found Linear connector: id={actual_connector_id}")
else:
- kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == actual_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.LINEAR_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected Linear connector is invalid or has been disconnected.",
+ }
+ logger.info(f"Validated Linear connector: id={actual_connector_id}")
- return {
- "status": "success",
- "issue_id": result.get("id"),
- "identifier": result.get("identifier"),
- "url": result.get("url"),
- "message": (result.get("message", "") + kb_message_suffix),
- }
+ logger.info(
+ f"Creating Linear issue with final params: title='{final_title}'"
+ )
+ linear_client = LinearConnector(
+ session=db_session, connector_id=actual_connector_id
+ )
+ result = await linear_client.create_issue(
+ team_id=final_team_id,
+ title=final_title,
+ description=final_description,
+ state_id=final_state_id,
+ assignee_id=final_assignee_id,
+ priority=final_priority,
+ label_ids=final_label_ids if final_label_ids else None,
+ )
+
+ if result.get("status") == "error":
+ logger.error(
+ f"Failed to create Linear issue: {result.get('message')}"
+ )
+ return {"status": "error", "message": result.get("message")}
+
+ logger.info(
+ f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
+ )
+
+ kb_message_suffix = ""
+ try:
+ from app.services.linear import LinearKBSyncService
+
+ kb_service = LinearKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ issue_id=result.get("id"),
+ issue_identifier=result.get("identifier", ""),
+ issue_title=result.get("title", final_title),
+ issue_url=result.get("url"),
+ description=final_description,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "issue_id": result.get("id"),
+ "identifier": result.get("identifier"),
+ "url": result.get("url"),
+ "message": (result.get("message", "") + kb_message_suffix),
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
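
The create/update tools treat knowledge-base sync as best effort: a sync failure only changes the trailing message, never the tool's status. A sketch of that fallback, assuming a KB sync service exposing sync_after_create as used above; the helper is illustrative.

import logging

logger = logging.getLogger(__name__)


async def kb_message_suffix(kb_service, **sync_kwargs) -> str:
    try:
        kb_result = await kb_service.sync_after_create(**sync_kwargs)
        if kb_result["status"] == "success":
            return " Your knowledge base has also been updated."
    except Exception as kb_err:
        logger.warning(f"KB sync after create failed: {kb_err}")
    return (
        " This issue will be added to your knowledge base in the next"
        " scheduled sync."
    )
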
diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py
index 29ef0cdf2..c5039a8eb 100644
--- a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
+from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__)
@@ -17,11 +18,17 @@ def create_delete_linear_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
- """
- Factory function to create the delete_linear_issue tool.
+ """Factory function to create the delete_linear_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits.
Args:
- db_session: Database session for accessing the Linear connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector
user_id: User ID for finding the correct Linear connector
connector_id: Optional specific connector ID (if known)
@@ -29,6 +36,7 @@ def create_delete_linear_issue_tool(
Returns:
Configured delete_linear_issue tool
"""
+ del db_session # per-call session — see docstring
@tool
async def delete_linear_issue(
@@ -73,7 +81,7 @@ def create_delete_linear_issue_tool(
f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Linear tool not properly configured - missing required parameters"
)
@@ -83,149 +91,152 @@ def create_delete_linear_issue_tool(
}
try:
- metadata_service = LinearToolMetadataService(db_session)
- context = await metadata_service.get_delete_context(
- search_space_id, user_id, issue_ref
- )
-
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
- logger.warning(f"Auth expired for delete context: {error_msg}")
- return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "linear",
- }
- if "not found" in error_msg.lower():
- logger.warning(f"Issue not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- else:
- logger.error(f"Failed to fetch delete context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- issue_id = context["issue"]["id"]
- issue_identifier = context["issue"].get("identifier", "")
- document_id = context["issue"]["document_id"]
- connector_id_from_context = context.get("workspace", {}).get("id")
-
- logger.info(
- f"Requesting approval for deleting Linear issue: '{issue_ref}' "
- f"(id={issue_id}, delete_from_kb={delete_from_kb})"
- )
- result = request_approval(
- action_type="linear_issue_deletion",
- tool_name="delete_linear_issue",
- params={
- "issue_id": issue_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Linear issue deletion rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_issue_id = result.params.get("issue_id", issue_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- logger.info(
- f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
- f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
- )
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if final_connector_id:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.LINEAR_CONNECTOR,
- )
+ async with async_session_maker() as db_session:
+ metadata_service = LinearToolMetadataService(db_session)
+ context = await metadata_service.get_delete_context(
+ search_space_id, user_id, issue_ref
)
- connector = result.scalars().first()
- if not connector:
- logger.error(
- f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
+
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ logger.warning(f"Auth expired for delete context: {error_msg}")
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "linear",
+ }
+ if "not found" in error_msg.lower():
+ logger.warning(f"Issue not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ else:
+ logger.error(f"Failed to fetch delete context: {error_msg}")
+ return {"status": "error", "message": error_msg}
+
+ issue_id = context["issue"]["id"]
+ issue_identifier = context["issue"].get("identifier", "")
+ document_id = context["issue"]["document_id"]
+ connector_id_from_context = context.get("workspace", {}).get("id")
+
+ logger.info(
+ f"Requesting approval for deleting Linear issue: '{issue_ref}' "
+ f"(id={issue_id}, delete_from_kb={delete_from_kb})"
+ )
+ result = request_approval(
+ action_type="linear_issue_deletion",
+ tool_name="delete_linear_issue",
+ params={
+ "issue_id": issue_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ logger.info("Linear issue deletion rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_issue_id = result.params.get("issue_id", issue_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
+
+ logger.info(
+ f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
+ f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
+ )
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if final_connector_id:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.LINEAR_CONNECTOR,
+ )
)
+ connector = result.scalars().first()
+ if not connector:
+ logger.error(
+ f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "Selected Linear connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = connector.id
+ logger.info(f"Validated Linear connector: id={actual_connector_id}")
+ else:
+ logger.error("No connector found for this issue")
return {
"status": "error",
- "message": "Selected Linear connector is invalid or has been disconnected.",
+ "message": "No connector found for this issue.",
}
- actual_connector_id = connector.id
- logger.info(f"Validated Linear connector: id={actual_connector_id}")
- else:
- logger.error("No connector found for this issue")
- return {
- "status": "error",
- "message": "No connector found for this issue.",
- }
- linear_client = LinearConnector(
- session=db_session, connector_id=actual_connector_id
- )
+ linear_client = LinearConnector(
+ session=db_session, connector_id=actual_connector_id
+ )
- result = await linear_client.archive_issue(issue_id=final_issue_id)
+ result = await linear_client.archive_issue(issue_id=final_issue_id)
- logger.info(
- f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
- )
+ logger.info(
+ f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
+ )
- deleted_from_kb = False
- if (
- result.get("status") == "success"
- and final_delete_from_kb
- and document_id
- ):
- try:
- from app.db import Document
+ deleted_from_kb = False
+ if (
+ result.get("status") == "success"
+ and final_delete_from_kb
+ and document_id
+ ):
+ try:
+ from app.db import Document
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- document = doc_result.scalars().first()
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ result["warning"] = (
+ f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
)
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- result["warning"] = (
- f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
- )
- if result.get("status") == "success":
- result["deleted_from_kb"] = deleted_from_kb
- if issue_identifier:
- result["message"] = (
- f"Issue {issue_identifier} archived successfully."
- )
- if deleted_from_kb:
- result["message"] = (
- f"{result.get('message', '')} Also removed from the knowledge base."
- )
+ if result.get("status") == "success":
+ result["deleted_from_kb"] = deleted_from_kb
+ if issue_identifier:
+ result["message"] = (
+ f"Issue {issue_identifier} archived successfully."
+ )
+ if deleted_from_kb:
+ result["message"] = (
+ f"{result.get('message', '')} Also removed from the knowledge base."
+ )
- return result
+ return result
except Exception as e:
from langgraph.errors import GraphInterrupt
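
The docstrings added above describe the pattern these tool-factory hunks apply: the factory keeps its db_session parameter so existing registry call sites keep working, discards it, and every tool invocation opens its own short-lived session. A minimal standalone sketch of that pattern follows; the engine URL and the name make_lookup_tool are illustrative only and not taken from the repository, while async_session_maker mirrors the helper this patch imports from app.db.

# Sketch only: per-call-session tool factory, assuming the SQLAlchemy 2.x async API.
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

# Illustrative engine and URL; the backend imports its async_session_maker from app.db.
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)


def make_lookup_tool(db_session: AsyncSession | None = None, connector_id: int | None = None):
    # Kept for registry compatibility, deliberately unused: the closure can
    # outlive the request that built it (compiled-agent cache), so a captured
    # session may already be closed by the time the tool finally runs.
    del db_session

    async def lookup() -> str:
        # Each call owns its session and releases it on exit, so a cached
        # closure never touches a stale or closed session.
        async with async_session_maker() as session:
            probe = (await session.execute(text("SELECT 1"))).scalar_one()
            return f"connector={connector_id}, probe={probe}"

    return lookup

In the patch itself the inner coroutine is additionally decorated with @tool and carries the HITL approval flow; the sketch isolates only the session-lifetime change.
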
diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py
index f35d0dddd..d610ce2b7 100644
--- a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py
+++ b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
+from app.db import async_session_maker
from app.services.linear import LinearKBSyncService, LinearToolMetadataService
logger = logging.getLogger(__name__)
@@ -17,11 +18,17 @@ def create_update_linear_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
- """
- Factory function to create the update_linear_issue tool.
+ """Factory function to create the update_linear_issue tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits.
Args:
- db_session: Database session for accessing the Linear connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@@ -29,6 +36,7 @@ def create_update_linear_issue_tool(
Returns:
Configured update_linear_issue tool
"""
+ del db_session # per-call session — see docstring
@tool
async def update_linear_issue(
@@ -86,7 +94,7 @@ def create_update_linear_issue_tool(
"""
logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Linear tool not properly configured - missing required parameters"
)
@@ -96,176 +104,177 @@ def create_update_linear_issue_tool(
}
try:
- metadata_service = LinearToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, issue_ref
- )
-
- if "error" in context:
- error_msg = context["error"]
- if context.get("auth_expired"):
- logger.warning(f"Auth expired for update context: {error_msg}")
- return {
- "status": "auth_error",
- "message": error_msg,
- "connector_id": context.get("connector_id"),
- "connector_type": "linear",
- }
- if "not found" in error_msg.lower():
- logger.warning(f"Issue not found: {error_msg}")
- return {"status": "not_found", "message": error_msg}
- else:
- logger.error(f"Failed to fetch update context: {error_msg}")
- return {"status": "error", "message": error_msg}
-
- issue_id = context["issue"]["id"]
- document_id = context["issue"]["document_id"]
- connector_id_from_context = context.get("workspace", {}).get("id")
-
- team = context.get("team", {})
- new_state_id = _resolve_state(team, new_state_name)
- new_assignee_id = _resolve_assignee(team, new_assignee_email)
- new_label_ids = _resolve_labels(team, new_label_names)
-
- logger.info(
- f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
- )
- result = request_approval(
- action_type="linear_issue_update",
- tool_name="update_linear_issue",
- params={
- "issue_id": issue_id,
- "document_id": document_id,
- "new_title": new_title,
- "new_description": new_description,
- "new_state_id": new_state_id,
- "new_assignee_id": new_assignee_id,
- "new_priority": new_priority,
- "new_label_ids": new_label_ids,
- "connector_id": connector_id_from_context,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Linear issue update rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_issue_id = result.params.get("issue_id", issue_id)
- final_document_id = result.params.get("document_id", document_id)
- final_new_title = result.params.get("new_title", new_title)
- final_new_description = result.params.get(
- "new_description", new_description
- )
- final_new_state_id = result.params.get("new_state_id", new_state_id)
- final_new_assignee_id = result.params.get(
- "new_assignee_id", new_assignee_id
- )
- final_new_priority = result.params.get("new_priority", new_priority)
- final_new_label_ids: list[str] | None = result.params.get(
- "new_label_ids", new_label_ids
- )
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
-
- if not final_connector_id:
- logger.error("No connector found for this issue")
- return {
- "status": "error",
- "message": "No connector found for this issue.",
- }
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.LINEAR_CONNECTOR,
+ async with async_session_maker() as db_session:
+ metadata_service = LinearToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, issue_ref
)
- )
- connector = result.scalars().first()
- if not connector:
- logger.error(
- f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
- )
- return {
- "status": "error",
- "message": "Selected Linear connector is invalid or has been disconnected.",
- }
- logger.info(f"Validated Linear connector: id={final_connector_id}")
- logger.info(
- f"Updating Linear issue with final params: issue_id={final_issue_id}"
- )
- linear_client = LinearConnector(
- session=db_session, connector_id=final_connector_id
- )
- updated_issue = await linear_client.update_issue(
- issue_id=final_issue_id,
- title=final_new_title,
- description=final_new_description,
- state_id=final_new_state_id,
- assignee_id=final_new_assignee_id,
- priority=final_new_priority,
- label_ids=final_new_label_ids,
- )
+ if "error" in context:
+ error_msg = context["error"]
+ if context.get("auth_expired"):
+ logger.warning(f"Auth expired for update context: {error_msg}")
+ return {
+ "status": "auth_error",
+ "message": error_msg,
+ "connector_id": context.get("connector_id"),
+ "connector_type": "linear",
+ }
+ if "not found" in error_msg.lower():
+ logger.warning(f"Issue not found: {error_msg}")
+ return {"status": "not_found", "message": error_msg}
+ else:
+ logger.error(f"Failed to fetch update context: {error_msg}")
+ return {"status": "error", "message": error_msg}
- if updated_issue.get("status") == "error":
- logger.error(
- f"Failed to update Linear issue: {updated_issue.get('message')}"
- )
- return {
- "status": "error",
- "message": updated_issue.get("message"),
- }
+ issue_id = context["issue"]["id"]
+ document_id = context["issue"]["document_id"]
+ connector_id_from_context = context.get("workspace", {}).get("id")
- logger.info(
- f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}"
- )
+ team = context.get("team", {})
+ new_state_id = _resolve_state(team, new_state_name)
+ new_assignee_id = _resolve_assignee(team, new_assignee_email)
+ new_label_ids = _resolve_labels(team, new_label_names)
- if final_document_id is not None:
logger.info(
- f"Updating knowledge base for document {final_document_id}..."
+ f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
)
- kb_service = LinearKBSyncService(db_session)
- kb_result = await kb_service.sync_after_update(
- document_id=final_document_id,
- issue_id=final_issue_id,
- user_id=user_id,
- search_space_id=search_space_id,
+ result = request_approval(
+ action_type="linear_issue_update",
+ tool_name="update_linear_issue",
+ params={
+ "issue_id": issue_id,
+ "document_id": document_id,
+ "new_title": new_title,
+ "new_description": new_description,
+ "new_state_id": new_state_id,
+ "new_assignee_id": new_assignee_id,
+ "new_priority": new_priority,
+ "new_label_ids": new_label_ids,
+ "connector_id": connector_id_from_context,
+ },
+ context=context,
)
- if kb_result["status"] == "success":
- logger.info(
- f"Knowledge base successfully updated for issue {final_issue_id}"
- )
- kb_message = " Your knowledge base has also been updated."
- elif kb_result["status"] == "not_indexed":
- kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
- else:
- logger.warning(
- f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
- )
- kb_message = " Your knowledge base will be updated in the next scheduled sync."
- else:
- kb_message = ""
- identifier = updated_issue.get("identifier")
- default_msg = f"Issue {identifier} updated successfully."
- return {
- "status": "success",
- "identifier": identifier,
- "url": updated_issue.get("url"),
- "message": f"{updated_issue.get('message', default_msg)}{kb_message}",
- }
+ if result.rejected:
+ logger.info("Linear issue update rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_issue_id = result.params.get("issue_id", issue_id)
+ final_document_id = result.params.get("document_id", document_id)
+ final_new_title = result.params.get("new_title", new_title)
+ final_new_description = result.params.get(
+ "new_description", new_description
+ )
+ final_new_state_id = result.params.get("new_state_id", new_state_id)
+ final_new_assignee_id = result.params.get(
+ "new_assignee_id", new_assignee_id
+ )
+ final_new_priority = result.params.get("new_priority", new_priority)
+ final_new_label_ids: list[str] | None = result.params.get(
+ "new_label_ids", new_label_ids
+ )
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+
+ if not final_connector_id:
+ logger.error("No connector found for this issue")
+ return {
+ "status": "error",
+ "message": "No connector found for this issue.",
+ }
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.LINEAR_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ if not connector:
+ logger.error(
+ f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "Selected Linear connector is invalid or has been disconnected.",
+ }
+ logger.info(f"Validated Linear connector: id={final_connector_id}")
+
+ logger.info(
+ f"Updating Linear issue with final params: issue_id={final_issue_id}"
+ )
+ linear_client = LinearConnector(
+ session=db_session, connector_id=final_connector_id
+ )
+ updated_issue = await linear_client.update_issue(
+ issue_id=final_issue_id,
+ title=final_new_title,
+ description=final_new_description,
+ state_id=final_new_state_id,
+ assignee_id=final_new_assignee_id,
+ priority=final_new_priority,
+ label_ids=final_new_label_ids,
+ )
+
+ if updated_issue.get("status") == "error":
+ logger.error(
+ f"Failed to update Linear issue: {updated_issue.get('message')}"
+ )
+ return {
+ "status": "error",
+ "message": updated_issue.get("message"),
+ }
+
+ logger.info(
+ f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}"
+ )
+
+ if final_document_id is not None:
+ logger.info(
+ f"Updating knowledge base for document {final_document_id}..."
+ )
+ kb_service = LinearKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_update(
+ document_id=final_document_id,
+ issue_id=final_issue_id,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ )
+ if kb_result["status"] == "success":
+ logger.info(
+ f"Knowledge base successfully updated for issue {final_issue_id}"
+ )
+ kb_message = " Your knowledge base has also been updated."
+ elif kb_result["status"] == "not_indexed":
+ kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
+ else:
+ logger.warning(
+ f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
+ )
+ kb_message = " Your knowledge base will be updated in the next scheduled sync."
+ else:
+ kb_message = ""
+
+ identifier = updated_issue.get("identifier")
+ default_msg = f"Issue {identifier} updated successfully."
+ return {
+ "status": "success",
+ "identifier": identifier,
+ "url": updated_issue.get("url"),
+ "message": f"{updated_issue.get('message', default_msg)}{kb_message}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
index 0a24a988f..65c177d7a 100644
--- a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
@@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
@@ -17,6 +18,23 @@ def create_create_luma_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_luma_event tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the compiled-agent cache can safely
+ reuse the closure across HTTP requests. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured create_luma_event tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_luma_event(
name: str,
@@ -40,83 +58,86 @@ def create_create_luma_event_tool(
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
try:
- connector = await get_luma_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Luma connector found."}
+ async with async_session_maker() as db_session:
+ connector = await get_luma_connector(
+ db_session, search_space_id, user_id
+ )
+ if not connector:
+ return {"status": "error", "message": "No Luma connector found."}
- result = request_approval(
- action_type="luma_create_event",
- tool_name="create_luma_event",
- params={
- "name": name,
- "start_at": start_at,
- "end_at": end_at,
- "description": description,
- "timezone": timezone,
- },
- context={"connector_id": connector.id},
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Event was not created.",
- }
-
- final_name = result.params.get("name", name)
- final_start = result.params.get("start_at", start_at)
- final_end = result.params.get("end_at", end_at)
- final_desc = result.params.get("description", description)
- final_tz = result.params.get("timezone", timezone)
-
- api_key = get_api_key(connector)
- headers = luma_headers(api_key)
-
- body: dict[str, Any] = {
- "name": final_name,
- "start_at": final_start,
- "end_at": final_end,
- "timezone": final_tz,
- }
- if final_desc:
- body["description_md"] = final_desc
-
- async with httpx.AsyncClient(timeout=20.0) as client:
- resp = await client.post(
- f"{LUMA_API}/event/create",
- headers=headers,
- json=body,
+ result = request_approval(
+ action_type="luma_create_event",
+ tool_name="create_luma_event",
+ params={
+ "name": name,
+ "start_at": start_at,
+ "end_at": end_at,
+ "description": description,
+ "timezone": timezone,
+ },
+ context={"connector_id": connector.id},
)
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Luma API key is invalid.",
- "connector_type": "luma",
- }
- if resp.status_code == 403:
- return {
- "status": "error",
- "message": "Luma Plus subscription required to create events via API.",
- }
- if resp.status_code not in (200, 201):
- return {
- "status": "error",
- "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Event was not created.",
+ }
- data = resp.json()
- event_id = data.get("api_id") or data.get("event", {}).get("api_id")
+ final_name = result.params.get("name", name)
+ final_start = result.params.get("start_at", start_at)
+ final_end = result.params.get("end_at", end_at)
+ final_desc = result.params.get("description", description)
+ final_tz = result.params.get("timezone", timezone)
- return {
- "status": "success",
- "event_id": event_id,
- "message": f"Event '{final_name}' created on Luma.",
- }
+ api_key = get_api_key(connector)
+ headers = luma_headers(api_key)
+
+ body: dict[str, Any] = {
+ "name": final_name,
+ "start_at": final_start,
+ "end_at": final_end,
+ "timezone": final_tz,
+ }
+ if final_desc:
+ body["description_md"] = final_desc
+
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ resp = await client.post(
+ f"{LUMA_API}/event/create",
+ headers=headers,
+ json=body,
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Luma API key is invalid.",
+ "connector_type": "luma",
+ }
+ if resp.status_code == 403:
+ return {
+ "status": "error",
+ "message": "Luma Plus subscription required to create events via API.",
+ }
+ if resp.status_code not in (200, 201):
+ return {
+ "status": "error",
+ "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}",
+ }
+
+ data = resp.json()
+ event_id = data.get("api_id") or data.get("event", {}).get("api_id")
+
+ return {
+ "status": "success",
+ "event_id": event_id,
+ "message": f"Event '{final_name}' created on Luma.",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
index aec5ad220..6885c2049 100644
--- a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
+++ b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_list_luma_events_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the list_luma_events tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the compiled-agent cache can safely
+ reuse the closure across HTTP requests. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured list_luma_events tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def list_luma_events(
max_results: int = 25,
@@ -28,77 +47,80 @@ def create_list_luma_events_tool(
Dictionary with status and a list of events including
event_id, name, start_at, end_at, location, url.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
max_results = min(max_results, 50)
try:
- connector = await get_luma_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Luma connector found."}
+ async with async_session_maker() as db_session:
+ connector = await get_luma_connector(
+ db_session, search_space_id, user_id
+ )
+ if not connector:
+ return {"status": "error", "message": "No Luma connector found."}
- api_key = get_api_key(connector)
- headers = luma_headers(api_key)
+ api_key = get_api_key(connector)
+ headers = luma_headers(api_key)
- all_entries: list[dict] = []
- cursor = None
+ all_entries: list[dict] = []
+ cursor = None
- async with httpx.AsyncClient(timeout=20.0) as client:
- while len(all_entries) < max_results:
- params: dict[str, Any] = {
- "limit": min(100, max_results - len(all_entries))
- }
- if cursor:
- params["cursor"] = cursor
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ while len(all_entries) < max_results:
+ params: dict[str, Any] = {
+ "limit": min(100, max_results - len(all_entries))
+ }
+ if cursor:
+ params["cursor"] = cursor
- resp = await client.get(
- f"{LUMA_API}/calendar/list-events",
- headers=headers,
- params=params,
+ resp = await client.get(
+ f"{LUMA_API}/calendar/list-events",
+ headers=headers,
+ params=params,
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Luma API key is invalid.",
+ "connector_type": "luma",
+ }
+ if resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Luma API error: {resp.status_code}",
+ }
+
+ data = resp.json()
+ entries = data.get("entries", [])
+ if not entries:
+ break
+ all_entries.extend(entries)
+
+ next_cursor = data.get("next_cursor")
+ if not next_cursor:
+ break
+ cursor = next_cursor
+
+ events = []
+ for entry in all_entries[:max_results]:
+ ev = entry.get("event", {})
+ geo = ev.get("geo_info", {})
+ events.append(
+ {
+ "event_id": entry.get("api_id"),
+ "name": ev.get("name", "Untitled"),
+ "start_at": ev.get("start_at", ""),
+ "end_at": ev.get("end_at", ""),
+ "timezone": ev.get("timezone", ""),
+ "location": geo.get("name", ""),
+ "url": ev.get("url", ""),
+ "visibility": ev.get("visibility", ""),
+ }
)
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Luma API key is invalid.",
- "connector_type": "luma",
- }
- if resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Luma API error: {resp.status_code}",
- }
-
- data = resp.json()
- entries = data.get("entries", [])
- if not entries:
- break
- all_entries.extend(entries)
-
- next_cursor = data.get("next_cursor")
- if not next_cursor:
- break
- cursor = next_cursor
-
- events = []
- for entry in all_entries[:max_results]:
- ev = entry.get("event", {})
- geo = ev.get("geo_info", {})
- events.append(
- {
- "event_id": entry.get("api_id"),
- "name": ev.get("name", "Untitled"),
- "start_at": ev.get("start_at", ""),
- "end_at": ev.get("end_at", ""),
- "timezone": ev.get("timezone", ""),
- "location": geo.get("name", ""),
- "url": ev.get("url", ""),
- "visibility": ev.get("visibility", ""),
- }
- )
-
- return {"status": "success", "events": events, "total": len(events)}
+ return {"status": "success", "events": events, "total": len(events)}
except Exception as e:
from langgraph.errors import GraphInterrupt
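
The list_luma_events hunks above are noisy because the whole body moved one level deeper inside the new async with block, but the pagination itself is unchanged by this patch. Stripped of the per-status error handling, it reduces to the loop below; the helper name, the base_url parameter, and the use of raise_for_status are illustrative and not part of the patch.

# Sketch only: cursor pagination over the Luma calendar listing, capped at max_results.
import httpx


async def fetch_calendar_entries(
    base_url: str, headers: dict[str, str], max_results: int = 25
) -> list[dict]:
    max_results = min(max_results, 50)  # same cap the tool enforces
    entries: list[dict] = []
    cursor: str | None = None
    async with httpx.AsyncClient(timeout=20.0) as client:
        while len(entries) < max_results:
            params: dict = {"limit": min(100, max_results - len(entries))}
            if cursor:
                params["cursor"] = cursor
            resp = await client.get(
                f"{base_url}/calendar/list-events", headers=headers, params=params
            )
            resp.raise_for_status()  # the real tool branches on 401 / other statuses instead
            data = resp.json()
            page = data.get("entries", [])
            if not page:
                break
            entries.extend(page)
            cursor = data.get("next_cursor")
            if not cursor:
                break
    return entries[:max_results]
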
diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py
index b37a9d617..a8484e9c0 100644
--- a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_read_luma_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the read_luma_event tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the compiled-agent cache can safely
+ reuse the closure across HTTP requests. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured read_luma_event tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def read_luma_event(event_id: str) -> dict[str, Any]:
"""Read detailed information about a specific Luma event.
@@ -26,60 +45,63 @@ def create_read_luma_event_tool(
Dictionary with status and full event details including
description, attendees count, meeting URL.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
try:
- connector = await get_luma_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Luma connector found."}
-
- api_key = get_api_key(connector)
- headers = luma_headers(api_key)
-
- async with httpx.AsyncClient(timeout=15.0) as client:
- resp = await client.get(
- f"{LUMA_API}/events/{event_id}",
- headers=headers,
+ async with async_session_maker() as db_session:
+ connector = await get_luma_connector(
+ db_session, search_space_id, user_id
)
+ if not connector:
+ return {"status": "error", "message": "No Luma connector found."}
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Luma API key is invalid.",
- "connector_type": "luma",
- }
- if resp.status_code == 404:
- return {
- "status": "not_found",
- "message": f"Event '{event_id}' not found.",
- }
- if resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Luma API error: {resp.status_code}",
+ api_key = get_api_key(connector)
+ headers = luma_headers(api_key)
+
+ async with httpx.AsyncClient(timeout=15.0) as client:
+ resp = await client.get(
+ f"{LUMA_API}/events/{event_id}",
+ headers=headers,
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Luma API key is invalid.",
+ "connector_type": "luma",
+ }
+ if resp.status_code == 404:
+ return {
+ "status": "not_found",
+ "message": f"Event '{event_id}' not found.",
+ }
+ if resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Luma API error: {resp.status_code}",
+ }
+
+ data = resp.json()
+ ev = data.get("event", data)
+ geo = ev.get("geo_info", {})
+
+ event_detail = {
+ "event_id": event_id,
+ "name": ev.get("name", ""),
+ "description": ev.get("description", ""),
+ "start_at": ev.get("start_at", ""),
+ "end_at": ev.get("end_at", ""),
+ "timezone": ev.get("timezone", ""),
+ "location_name": geo.get("name", ""),
+ "address": geo.get("address", ""),
+ "url": ev.get("url", ""),
+ "meeting_url": ev.get("meeting_url", ""),
+ "visibility": ev.get("visibility", ""),
+ "cover_url": ev.get("cover_url", ""),
}
- data = resp.json()
- ev = data.get("event", data)
- geo = ev.get("geo_info", {})
-
- event_detail = {
- "event_id": event_id,
- "name": ev.get("name", ""),
- "description": ev.get("description", ""),
- "start_at": ev.get("start_at", ""),
- "end_at": ev.get("end_at", ""),
- "timezone": ev.get("timezone", ""),
- "location_name": geo.get("name", ""),
- "address": geo.get("address", ""),
- "url": ev.get("url", ""),
- "meeting_url": ev.get("meeting_url", ""),
- "visibility": ev.get("visibility", ""),
- "cover_url": ev.get("cover_url", ""),
- }
-
- return {"status": "success", "event": event_detail}
+ return {"status": "success", "event": event_detail}
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py
index 6efffe960..6ec95e9f0 100644
--- a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
+from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__)
@@ -20,8 +21,17 @@ def create_create_notion_page_tool(
"""
Factory function to create the create_notion_page tool.
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`. This is critical for the compiled-agent
+ cache: the compiled graph (and therefore this closure) is reused
+ across HTTP requests, so capturing a per-request session here would
+ surface stale/closed sessions on cache hits. Per-call sessions also
+ keep the request's outer transaction from being held open across
+ long-running Notion API calls.
+
Args:
- db_session: Database session for accessing Notion connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@@ -29,6 +39,7 @@ def create_create_notion_page_tool(
Returns:
Configured create_notion_page tool
"""
+ del db_session # per-call session — see docstring
@tool
async def create_notion_page(
@@ -67,7 +78,7 @@ def create_create_notion_page_tool(
"""
logger.info(f"create_notion_page called: title='{title}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Notion tool not properly configured - missing required parameters"
)
@@ -77,154 +88,157 @@ def create_create_notion_page_tool(
}
try:
- metadata_service = NotionToolMetadataService(db_session)
- context = await metadata_service.get_creation_context(
- search_space_id, user_id
- )
-
- if "error" in context:
- logger.error(f"Failed to fetch creation context: {context['error']}")
- return {
- "status": "error",
- "message": context["error"],
- }
-
- accounts = context.get("accounts", [])
- if accounts and all(a.get("auth_expired") for a in accounts):
- logger.warning("All Notion accounts have expired authentication")
- return {
- "status": "auth_error",
- "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "notion",
- }
-
- logger.info(f"Requesting approval for creating Notion page: '{title}'")
- result = request_approval(
- action_type="notion_page_creation",
- tool_name="create_notion_page",
- params={
- "title": title,
- "content": content,
- "parent_page_id": None,
- "connector_id": connector_id,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Notion page creation rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_title = result.params.get("title", title)
- final_content = result.params.get("content", content)
- final_parent_page_id = result.params.get("parent_page_id")
- final_connector_id = result.params.get("connector_id", connector_id)
-
- if not final_title or not final_title.strip():
- logger.error("Title is empty or contains only whitespace")
- return {
- "status": "error",
- "message": "Page title cannot be empty. Please provide a valid title.",
- }
-
- logger.info(
- f"Creating Notion page with final params: title='{final_title}'"
- )
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- actual_connector_id = final_connector_id
- if actual_connector_id is None:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.NOTION_CONNECTOR,
- )
+ async with async_session_maker() as db_session:
+ metadata_service = NotionToolMetadataService(db_session)
+ context = await metadata_service.get_creation_context(
+ search_space_id, user_id
)
- connector = result.scalars().first()
- if not connector:
- logger.warning(
- f"No Notion connector found for search_space_id={search_space_id}"
- )
- return {
- "status": "error",
- "message": "No Notion connector found. Please connect Notion in your workspace settings.",
- }
-
- actual_connector_id = connector.id
- logger.info(f"Found Notion connector: id={actual_connector_id}")
- else:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == actual_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.NOTION_CONNECTOR,
- )
- )
- connector = result.scalars().first()
-
- if not connector:
+ if "error" in context:
logger.error(
- f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}"
+ f"Failed to fetch creation context: {context['error']}"
)
return {
"status": "error",
- "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
+ "message": context["error"],
}
- logger.info(f"Validated Notion connector: id={actual_connector_id}")
- notion_connector = NotionHistoryConnector(
- session=db_session,
- connector_id=actual_connector_id,
- )
+ accounts = context.get("accounts", [])
+ if accounts and all(a.get("auth_expired") for a in accounts):
+ logger.warning("All Notion accounts have expired authentication")
+ return {
+ "status": "auth_error",
+ "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "notion",
+ }
- result = await notion_connector.create_page(
- title=final_title,
- content=final_content,
- parent_page_id=final_parent_page_id,
- )
- logger.info(
- f"create_page result: {result.get('status')} - {result.get('message', '')}"
- )
+ logger.info(f"Requesting approval for creating Notion page: '{title}'")
+ result = request_approval(
+ action_type="notion_page_creation",
+ tool_name="create_notion_page",
+ params={
+ "title": title,
+ "content": content,
+ "parent_page_id": None,
+ "connector_id": connector_id,
+ },
+ context=context,
+ )
- if result.get("status") == "success":
- kb_message_suffix = ""
- try:
- from app.services.notion import NotionKBSyncService
+ if result.rejected:
+ logger.info("Notion page creation rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
- kb_service = NotionKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- page_id=result.get("page_id"),
- page_title=result.get("title", final_title),
- page_url=result.get("url"),
- content=final_content,
- connector_id=actual_connector_id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
- if kb_result["status"] == "success":
- kb_message_suffix = (
- " Your knowledge base has also been updated."
+ final_title = result.params.get("title", title)
+ final_content = result.params.get("content", content)
+ final_parent_page_id = result.params.get("parent_page_id")
+ final_connector_id = result.params.get("connector_id", connector_id)
+
+ if not final_title or not final_title.strip():
+ logger.error("Title is empty or contains only whitespace")
+ return {
+ "status": "error",
+ "message": "Page title cannot be empty. Please provide a valid title.",
+ }
+
+ logger.info(
+ f"Creating Notion page with final params: title='{final_title}'"
+ )
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ actual_connector_id = final_connector_id
+ if actual_connector_id is None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.NOTION_CONNECTOR,
)
- else:
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ logger.warning(
+ f"No Notion connector found for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "No Notion connector found. Please connect Notion in your workspace settings.",
+ }
+
+ actual_connector_id = connector.id
+ logger.info(f"Found Notion connector: id={actual_connector_id}")
+ else:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == actual_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.NOTION_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ logger.error(
+ f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
+ }
+ logger.info(f"Validated Notion connector: id={actual_connector_id}")
+
+ notion_connector = NotionHistoryConnector(
+ session=db_session,
+ connector_id=actual_connector_id,
+ )
+
+ result = await notion_connector.create_page(
+ title=final_title,
+ content=final_content,
+ parent_page_id=final_parent_page_id,
+ )
+ logger.info(
+ f"create_page result: {result.get('status')} - {result.get('message', '')}"
+ )
+
+ if result.get("status") == "success":
+ kb_message_suffix = ""
+ try:
+ from app.services.notion import NotionKBSyncService
+
+ kb_service = NotionKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ page_id=result.get("page_id"),
+ page_title=result.get("title", final_title),
+ page_url=result.get("url"),
+ content=final_content,
+ connector_id=actual_connector_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
- result["message"] = result.get("message", "") + kb_message_suffix
+ result["message"] = result.get("message", "") + kb_message_suffix
- return result
+ return result
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py
index 07f7583d2..7b85da4c2 100644
--- a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
+from app.db import async_session_maker
from app.services.notion.tool_metadata_service import NotionToolMetadataService
logger = logging.getLogger(__name__)
@@ -20,8 +21,14 @@ def create_delete_notion_page_tool(
"""
Factory function to create the delete_notion_page tool.
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker`, so the compiled-agent cache can safely
+ reuse the closure across HTTP requests. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
Args:
- db_session: Database session for accessing Notion connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector
user_id: User ID for finding the correct Notion connector
connector_id: Optional specific connector ID (if known)
@@ -29,6 +36,7 @@ def create_delete_notion_page_tool(
Returns:
Configured delete_notion_page tool
"""
+ del db_session # per-call session — see docstring
@tool
async def delete_notion_page(
@@ -63,7 +71,7 @@ def create_delete_notion_page_tool(
f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Notion tool not properly configured - missing required parameters"
)
@@ -73,164 +81,167 @@ def create_delete_notion_page_tool(
}
try:
- # Get page context (page_id, account, title) from indexed data
- metadata_service = NotionToolMetadataService(db_session)
- context = await metadata_service.get_delete_context(
- search_space_id, user_id, page_title
- )
-
- if "error" in context:
- error_msg = context["error"]
- # Check if it's a "not found" error (softer handling for LLM)
- if "not found" in error_msg.lower():
- logger.warning(f"Page not found: {error_msg}")
- return {
- "status": "not_found",
- "message": error_msg,
- }
- else:
- logger.error(f"Failed to fetch delete context: {error_msg}")
- return {
- "status": "error",
- "message": error_msg,
- }
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Notion account %s has expired authentication",
- account.get("id"),
+ async with async_session_maker() as db_session:
+ # Get page context (page_id, account, title) from indexed data
+ metadata_service = NotionToolMetadataService(db_session)
+ context = await metadata_service.get_delete_context(
+ search_space_id, user_id, page_title
)
- return {
- "status": "auth_error",
- "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
- }
- page_id = context.get("page_id")
- connector_id_from_context = account.get("id")
- document_id = context.get("document_id")
-
- logger.info(
- f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
- )
-
- result = request_approval(
- action_type="notion_page_deletion",
- tool_name="delete_notion_page",
- params={
- "page_id": page_id,
- "connector_id": connector_id_from_context,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Notion page deletion rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_page_id = result.params.get("page_id", page_id)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- logger.info(
- f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
- )
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- # Validate the connector
- if final_connector_id:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.NOTION_CONNECTOR,
- )
- )
- connector = result.scalars().first()
-
- if not connector:
- logger.error(
- f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
- )
- return {
- "status": "error",
- "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
- }
- actual_connector_id = connector.id
- logger.info(f"Validated Notion connector: id={actual_connector_id}")
- else:
- logger.error("No connector found for this page")
- return {
- "status": "error",
- "message": "No connector found for this page.",
- }
-
- # Create connector instance
- notion_connector = NotionHistoryConnector(
- session=db_session,
- connector_id=actual_connector_id,
- )
-
- # Delete the page from Notion
- result = await notion_connector.delete_page(page_id=final_page_id)
- logger.info(
- f"delete_page result: {result.get('status')} - {result.get('message', '')}"
- )
-
- # If deletion was successful and user wants to delete from KB
- deleted_from_kb = False
- if (
- result.get("status") == "success"
- and final_delete_from_kb
- and document_id
- ):
- try:
- from sqlalchemy.future import select
-
- from app.db import Document
-
- # Get the document
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- document = doc_result.scalars().first()
-
- if document:
- await db_session.delete(document)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
+ if "error" in context:
+ error_msg = context["error"]
+ # Check if it's a "not found" error (softer handling for LLM)
+ if "not found" in error_msg.lower():
+ logger.warning(f"Page not found: {error_msg}")
+ return {
+ "status": "not_found",
+ "message": error_msg,
+ }
else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- result["warning"] = (
- f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
- )
+ logger.error(f"Failed to fetch delete context: {error_msg}")
+ return {
+ "status": "error",
+ "message": error_msg,
+ }
- # Update result with KB deletion status
- if result.get("status") == "success":
- result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- result["message"] = (
- f"{result.get('message', '')} (also removed from knowledge base)"
+ account = context.get("account", {})
+ if account.get("auth_expired"):
+ logger.warning(
+ "Notion account %s has expired authentication",
+ account.get("id"),
)
+ return {
+ "status": "auth_error",
+ "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
+ }
- return result
+ page_id = context.get("page_id")
+ connector_id_from_context = account.get("id")
+ document_id = context.get("document_id")
+
+ logger.info(
+ f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
+ )
+
+ result = request_approval(
+ action_type="notion_page_deletion",
+ tool_name="delete_notion_page",
+ params={
+ "page_id": page_id,
+ "connector_id": connector_id_from_context,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ logger.info("Notion page deletion rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_page_id = result.params.get("page_id", page_id)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
+
+ logger.info(
+ f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
+ )
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ # Validate the connector
+ if final_connector_id:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.NOTION_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ logger.error(
+ f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
+ }
+ actual_connector_id = connector.id
+ logger.info(f"Validated Notion connector: id={actual_connector_id}")
+ else:
+ logger.error("No connector found for this page")
+ return {
+ "status": "error",
+ "message": "No connector found for this page.",
+ }
+
+ # Create connector instance
+ notion_connector = NotionHistoryConnector(
+ session=db_session,
+ connector_id=actual_connector_id,
+ )
+
+ # Delete the page from Notion
+ result = await notion_connector.delete_page(page_id=final_page_id)
+ logger.info(
+ f"delete_page result: {result.get('status')} - {result.get('message', '')}"
+ )
+
+ # If deletion was successful and user wants to delete from KB
+ deleted_from_kb = False
+ if (
+ result.get("status") == "success"
+ and final_delete_from_kb
+ and document_id
+ ):
+ try:
+ from sqlalchemy.future import select
+
+ from app.db import Document
+
+ # Get the document
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ document = doc_result.scalars().first()
+
+ if document:
+ await db_session.delete(document)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ result["warning"] = (
+ f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
+ )
+
+ # Update result with KB deletion status
+ if result.get("status") == "success":
+ result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ result["message"] = (
+ f"{result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return result
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py
index 85c08177c..df757476a 100644
--- a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py
+++ b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
+from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__)
@@ -20,8 +21,14 @@ def create_update_notion_page_tool(
"""
Factory function to create the update_notion_page tool.
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache (see
+ ``create_create_notion_page_tool`` for the full rationale).
+
Args:
- db_session: Database session for accessing Notion connector
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@@ -29,6 +36,7 @@ def create_update_notion_page_tool(
Returns:
Configured update_notion_page tool
"""
+ del db_session # per-call session — see docstring
@tool
async def update_notion_page(
@@ -71,7 +79,7 @@ def create_update_notion_page_tool(
f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
logger.error(
"Notion tool not properly configured - missing required parameters"
)
@@ -88,152 +96,155 @@ def create_update_notion_page_tool(
}
try:
- metadata_service = NotionToolMetadataService(db_session)
- context = await metadata_service.get_update_context(
- search_space_id, user_id, page_title
- )
-
- if "error" in context:
- error_msg = context["error"]
- # Check if it's a "not found" error (softer handling for LLM)
- if "not found" in error_msg.lower():
- logger.warning(f"Page not found: {error_msg}")
- return {
- "status": "not_found",
- "message": error_msg,
- }
- else:
- logger.error(f"Failed to fetch update context: {error_msg}")
- return {
- "status": "error",
- "message": error_msg,
- }
-
- account = context.get("account", {})
- if account.get("auth_expired"):
- logger.warning(
- "Notion account %s has expired authentication",
- account.get("id"),
- )
- return {
- "status": "auth_error",
- "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
- }
-
- page_id = context.get("page_id")
- document_id = context.get("document_id")
- connector_id_from_context = context.get("account", {}).get("id")
-
- logger.info(
- f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
- )
- result = request_approval(
- action_type="notion_page_update",
- tool_name="update_notion_page",
- params={
- "page_id": page_id,
- "content": content,
- "connector_id": connector_id_from_context,
- },
- context=context,
- )
-
- if result.rejected:
- logger.info("Notion page update rejected by user")
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_page_id = result.params.get("page_id", page_id)
- final_content = result.params.get("content", content)
- final_connector_id = result.params.get(
- "connector_id", connector_id_from_context
- )
-
- logger.info(
- f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
- )
-
- from sqlalchemy.future import select
-
- from app.db import SearchSourceConnector, SearchSourceConnectorType
-
- if final_connector_id:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.NOTION_CONNECTOR,
- )
- )
- connector = result.scalars().first()
-
- if not connector:
- logger.error(
- f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
- )
- return {
- "status": "error",
- "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
- }
- actual_connector_id = connector.id
- logger.info(f"Validated Notion connector: id={actual_connector_id}")
- else:
- logger.error("No connector found for this page")
- return {
- "status": "error",
- "message": "No connector found for this page.",
- }
-
- notion_connector = NotionHistoryConnector(
- session=db_session,
- connector_id=actual_connector_id,
- )
-
- result = await notion_connector.update_page(
- page_id=final_page_id,
- content=final_content,
- )
- logger.info(
- f"update_page result: {result.get('status')} - {result.get('message', '')}"
- )
-
- if result.get("status") == "success" and document_id is not None:
- from app.services.notion import NotionKBSyncService
-
- logger.info(f"Updating knowledge base for document {document_id}...")
- kb_service = NotionKBSyncService(db_session)
- kb_result = await kb_service.sync_after_update(
- document_id=document_id,
- appended_content=final_content,
- user_id=user_id,
- search_space_id=search_space_id,
- appended_block_ids=result.get("appended_block_ids"),
+ async with async_session_maker() as db_session:
+ metadata_service = NotionToolMetadataService(db_session)
+ context = await metadata_service.get_update_context(
+ search_space_id, user_id, page_title
)
- if kb_result["status"] == "success":
- result["message"] = (
- f"{result['message']}. Your knowledge base has also been updated."
- )
- logger.info(
- f"Knowledge base successfully updated for page {final_page_id}"
- )
- elif kb_result["status"] == "not_indexed":
- result["message"] = (
- f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync."
- )
- else:
- result["message"] = (
- f"{result['message']}. Your knowledge base will be updated in the next scheduled sync."
- )
+ if "error" in context:
+ error_msg = context["error"]
+ # Check if it's a "not found" error (softer handling for LLM)
+ if "not found" in error_msg.lower():
+ logger.warning(f"Page not found: {error_msg}")
+ return {
+ "status": "not_found",
+ "message": error_msg,
+ }
+ else:
+ logger.error(f"Failed to fetch update context: {error_msg}")
+ return {
+ "status": "error",
+ "message": error_msg,
+ }
+
+ account = context.get("account", {})
+ if account.get("auth_expired"):
logger.warning(
- f"KB update failed for page {final_page_id}: {kb_result['message']}"
+ "Notion account %s has expired authentication",
+ account.get("id"),
+ )
+ return {
+ "status": "auth_error",
+ "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
+ }
+
+ page_id = context.get("page_id")
+ document_id = context.get("document_id")
+ connector_id_from_context = context.get("account", {}).get("id")
+
+ logger.info(
+ f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
+ )
+ result = request_approval(
+ action_type="notion_page_update",
+ tool_name="update_notion_page",
+ params={
+ "page_id": page_id,
+ "content": content,
+ "connector_id": connector_id_from_context,
+ },
+ context=context,
+ )
+
+ if result.rejected:
+ logger.info("Notion page update rejected by user")
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_page_id = result.params.get("page_id", page_id)
+ final_content = result.params.get("content", content)
+ final_connector_id = result.params.get(
+ "connector_id", connector_id_from_context
+ )
+
+ logger.info(
+ f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
+ )
+
+ from sqlalchemy.future import select
+
+ from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+ if final_connector_id:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.NOTION_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ logger.error(
+ f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
+ )
+ return {
+ "status": "error",
+ "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
+ }
+ actual_connector_id = connector.id
+ logger.info(f"Validated Notion connector: id={actual_connector_id}")
+ else:
+ logger.error("No connector found for this page")
+ return {
+ "status": "error",
+ "message": "No connector found for this page.",
+ }
+
+ notion_connector = NotionHistoryConnector(
+ session=db_session,
+ connector_id=actual_connector_id,
+ )
+
+ result = await notion_connector.update_page(
+ page_id=final_page_id,
+ content=final_content,
+ )
+ logger.info(
+ f"update_page result: {result.get('status')} - {result.get('message', '')}"
+ )
+
+ if result.get("status") == "success" and document_id is not None:
+ from app.services.notion import NotionKBSyncService
+
+ logger.info(
+ f"Updating knowledge base for document {document_id}..."
+ )
+ kb_service = NotionKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_update(
+ document_id=document_id,
+ appended_content=final_content,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ appended_block_ids=result.get("appended_block_ids"),
)
- return result
+ if kb_result["status"] == "success":
+ result["message"] = (
+ f"{result['message']}. Your knowledge base has also been updated."
+ )
+ logger.info(
+ f"Knowledge base successfully updated for page {final_page_id}"
+ )
+ elif kb_result["status"] == "not_indexed":
+ result["message"] = (
+ f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync."
+ )
+ else:
+ result["message"] = (
+ f"{result['message']}. Your knowledge base will be updated in the next scheduled sync."
+ )
+ logger.warning(
+ f"KB update failed for page {final_page_id}: {kb_result['message']}"
+ )
+
+ return result
except Exception as e:
from langgraph.errors import GraphInterrupt
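The per-call session pattern used throughout these tools reduces to a small factory shape. A minimal sketch follows; the factory name, tool body, and connection URL are illustrative placeholders, not SurfSense code, while ``async_session_maker``, ``AsyncSession``, and the ``@tool`` decorator mirror what the patch relies on:

    from langchain_core.tools import tool
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

    # Illustrative engine/session factory; app.db wires the real async_session_maker.
    engine = create_async_engine("postgresql+asyncpg://localhost/example")
    async_session_maker = async_sessionmaker(engine, expire_on_commit=False)


    def create_example_tool(
        db_session: AsyncSession | None = None,
        search_space_id: int | None = None,
    ):
        # The closure may be cached with the compiled agent, so it must not keep
        # the per-request session passed in by the registry.
        del db_session  # per-call session, as in the tools above

        @tool
        async def example_tool(query: str) -> dict:
            """Run a query for the search space using a fresh per-call session."""
            if search_space_id is None:
                return {"status": "error", "message": "Tool not configured."}
            async with async_session_maker() as session:
                # Queries against `session` go here; the session never outlives
                # this single tool invocation.
                return {"status": "success", "query": query}

        return example_tool

Because only scalar configuration (IDs) is captured by the closure, a cache hit on the compiled agent never replays a closed or stale session.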
diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py
index 21272e01d..5f199a41b 100644
--- a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py
@@ -10,7 +10,7 @@ from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.onedrive.client import OneDriveClient
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@@ -48,6 +48,23 @@ def create_create_onedrive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the create_onedrive_file tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured create_onedrive_file tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def create_onedrive_file(
name: str,
@@ -70,173 +87,178 @@ def create_create_onedrive_file_tool(
"""
logger.info(f"create_onedrive_file called: name='{name}'")
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "OneDrive tool not properly configured.",
}
try:
- result = await db_session.execute(
- select(SearchSourceConnector).filter(
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
- )
- )
- connectors = result.scalars().all()
-
- if not connectors:
- return {
- "status": "error",
- "message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
- }
-
- accounts = []
- for c in connectors:
- cfg = c.config or {}
- accounts.append(
- {
- "id": c.id,
- "name": c.name,
- "user_email": cfg.get("user_email"),
- "auth_expired": cfg.get("auth_expired", False),
- }
- )
-
- if all(a.get("auth_expired") for a in accounts):
- return {
- "status": "auth_error",
- "message": "All connected OneDrive accounts need re-authentication.",
- "connector_type": "onedrive",
- }
-
- parent_folders: dict[int, list[dict[str, str]]] = {}
- for acc in accounts:
- cid = acc["id"]
- if acc.get("auth_expired"):
- parent_folders[cid] = []
- continue
- try:
- client = OneDriveClient(session=db_session, connector_id=cid)
- items, err = await client.list_children("root")
- if err:
- logger.warning(
- "Failed to list folders for connector %s: %s", cid, err
- )
- parent_folders[cid] = []
- else:
- parent_folders[cid] = [
- {"folder_id": item["id"], "name": item["name"]}
- for item in items
- if item.get("folder") is not None
- and item.get("id")
- and item.get("name")
- ]
- except Exception:
- logger.warning(
- "Error fetching folders for connector %s", cid, exc_info=True
- )
- parent_folders[cid] = []
-
- context: dict[str, Any] = {
- "accounts": accounts,
- "parent_folders": parent_folders,
- }
-
- result = request_approval(
- action_type="onedrive_file_creation",
- tool_name="create_onedrive_file",
- params={
- "name": name,
- "content": content,
- "connector_id": None,
- "parent_folder_id": None,
- },
- context=context,
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
-
- final_name = result.params.get("name", name)
- final_content = result.params.get("content", content)
- final_connector_id = result.params.get("connector_id")
- final_parent_folder_id = result.params.get("parent_folder_id")
-
- if not final_name or not final_name.strip():
- return {"status": "error", "message": "File name cannot be empty."}
-
- final_name = _ensure_docx_extension(final_name)
-
- if final_connector_id is not None:
+ async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSourceConnector).filter(
- SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
)
)
- connector = result.scalars().first()
- else:
- connector = connectors[0]
+ connectors = result.scalars().all()
- if not connector:
- return {
- "status": "error",
- "message": "Selected OneDrive connector is invalid.",
+ if not connectors:
+ return {
+ "status": "error",
+ "message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
+ }
+
+ accounts = []
+ for c in connectors:
+ cfg = c.config or {}
+ accounts.append(
+ {
+ "id": c.id,
+ "name": c.name,
+ "user_email": cfg.get("user_email"),
+ "auth_expired": cfg.get("auth_expired", False),
+ }
+ )
+
+ if all(a.get("auth_expired") for a in accounts):
+ return {
+ "status": "auth_error",
+ "message": "All connected OneDrive accounts need re-authentication.",
+ "connector_type": "onedrive",
+ }
+
+ parent_folders: dict[int, list[dict[str, str]]] = {}
+ for acc in accounts:
+ cid = acc["id"]
+ if acc.get("auth_expired"):
+ parent_folders[cid] = []
+ continue
+ try:
+ client = OneDriveClient(session=db_session, connector_id=cid)
+ items, err = await client.list_children("root")
+ if err:
+ logger.warning(
+ "Failed to list folders for connector %s: %s", cid, err
+ )
+ parent_folders[cid] = []
+ else:
+ parent_folders[cid] = [
+ {"folder_id": item["id"], "name": item["name"]}
+ for item in items
+ if item.get("folder") is not None
+ and item.get("id")
+ and item.get("name")
+ ]
+ except Exception:
+ logger.warning(
+ "Error fetching folders for connector %s",
+ cid,
+ exc_info=True,
+ )
+ parent_folders[cid] = []
+
+ context: dict[str, Any] = {
+ "accounts": accounts,
+ "parent_folders": parent_folders,
}
- docx_bytes = _markdown_to_docx(final_content or "")
-
- client = OneDriveClient(session=db_session, connector_id=connector.id)
- created = await client.create_file(
- name=final_name,
- parent_id=final_parent_folder_id,
- content=docx_bytes,
- mime_type=DOCX_MIME,
- )
-
- logger.info(
- f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
- )
-
- kb_message_suffix = ""
- try:
- from app.services.onedrive import OneDriveKBSyncService
-
- kb_service = OneDriveKBSyncService(db_session)
- kb_result = await kb_service.sync_after_create(
- file_id=created.get("id"),
- file_name=created.get("name", final_name),
- mime_type=DOCX_MIME,
- web_url=created.get("webUrl"),
- content=final_content,
- connector_id=connector.id,
- search_space_id=search_space_id,
- user_id=user_id,
+ result = request_approval(
+ action_type="onedrive_file_creation",
+ tool_name="create_onedrive_file",
+ params={
+ "name": name,
+ "content": content,
+ "connector_id": None,
+ "parent_folder_id": None,
+ },
+ context=context,
)
- if kb_result["status"] == "success":
- kb_message_suffix = " Your knowledge base has also been updated."
- else:
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
- except Exception as kb_err:
- logger.warning(f"KB sync after create failed: {kb_err}")
- kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
- return {
- "status": "success",
- "file_id": created.get("id"),
- "name": created.get("name"),
- "web_url": created.get("webUrl"),
- "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_name = result.params.get("name", name)
+ final_content = result.params.get("content", content)
+ final_connector_id = result.params.get("connector_id")
+ final_parent_folder_id = result.params.get("parent_folder_id")
+
+ if not final_name or not final_name.strip():
+ return {"status": "error", "message": "File name cannot be empty."}
+
+ final_name = _ensure_docx_extension(final_name)
+
+ if final_connector_id is not None:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
+ )
+ )
+ connector = result.scalars().first()
+ else:
+ connector = connectors[0]
+
+ if not connector:
+ return {
+ "status": "error",
+ "message": "Selected OneDrive connector is invalid.",
+ }
+
+ docx_bytes = _markdown_to_docx(final_content or "")
+
+ client = OneDriveClient(session=db_session, connector_id=connector.id)
+ created = await client.create_file(
+ name=final_name,
+ parent_id=final_parent_folder_id,
+ content=docx_bytes,
+ mime_type=DOCX_MIME,
+ )
+
+ logger.info(
+ f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
+ )
+
+ kb_message_suffix = ""
+ try:
+ from app.services.onedrive import OneDriveKBSyncService
+
+ kb_service = OneDriveKBSyncService(db_session)
+ kb_result = await kb_service.sync_after_create(
+ file_id=created.get("id"),
+ file_name=created.get("name", final_name),
+ mime_type=DOCX_MIME,
+ web_url=created.get("webUrl"),
+ content=final_content,
+ connector_id=connector.id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ if kb_result["status"] == "success":
+ kb_message_suffix = (
+ " Your knowledge base has also been updated."
+ )
+ else:
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+ except Exception as kb_err:
+ logger.warning(f"KB sync after create failed: {kb_err}")
+ kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
+
+ return {
+ "status": "success",
+ "file_id": created.get("id"),
+ "name": created.get("name"),
+ "web_url": created.get("webUrl"),
+ "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py
index a7f13b5df..4857ea988 100644
--- a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py
+++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py
@@ -13,6 +13,7 @@ from app.db import (
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
+ async_session_maker,
)
logger = logging.getLogger(__name__)
@@ -23,6 +24,23 @@ def create_delete_onedrive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the delete_onedrive_file tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured delete_onedrive_file tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def delete_onedrive_file(
file_name: str,
@@ -56,33 +74,14 @@ def create_delete_onedrive_file_tool(
f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
)
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "OneDrive tool not properly configured.",
}
try:
- doc_result = await db_session.execute(
- select(Document)
- .join(
- SearchSourceConnector,
- Document.connector_id == SearchSourceConnector.id,
- )
- .filter(
- and_(
- Document.search_space_id == search_space_id,
- Document.document_type == DocumentType.ONEDRIVE_FILE,
- func.lower(Document.title) == func.lower(file_name),
- SearchSourceConnector.user_id == user_id,
- )
- )
- .order_by(Document.updated_at.desc().nullslast())
- .limit(1)
- )
- document = doc_result.scalars().first()
-
- if not document:
+ async with async_session_maker() as db_session:
doc_result = await db_session.execute(
select(Document)
.join(
@@ -93,13 +92,7 @@ def create_delete_onedrive_file_tool(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.ONEDRIVE_FILE,
- func.lower(
- cast(
- Document.document_metadata["onedrive_file_name"],
- String,
- )
- )
- == func.lower(file_name),
+ func.lower(Document.title) == func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
@@ -108,98 +101,64 @@ def create_delete_onedrive_file_tool(
)
document = doc_result.scalars().first()
- if not document:
- return {
- "status": "not_found",
- "message": (
- f"File '{file_name}' not found in your indexed OneDrive files. "
- "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
- "or (3) the file name is different."
- ),
- }
-
- if not document.connector_id:
- return {
- "status": "error",
- "message": "Document has no associated connector.",
- }
-
- meta = document.document_metadata or {}
- file_id = meta.get("onedrive_file_id")
- document_id = document.id
-
- if not file_id:
- return {
- "status": "error",
- "message": "File ID is missing. Please re-index the file.",
- }
-
- conn_result = await db_session.execute(
- select(SearchSourceConnector).filter(
- and_(
- SearchSourceConnector.id == document.connector_id,
- SearchSourceConnector.search_space_id == search_space_id,
- SearchSourceConnector.user_id == user_id,
- SearchSourceConnector.connector_type
- == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
+ if not document:
+ doc_result = await db_session.execute(
+ select(Document)
+ .join(
+ SearchSourceConnector,
+ Document.connector_id == SearchSourceConnector.id,
+ )
+ .filter(
+ and_(
+ Document.search_space_id == search_space_id,
+ Document.document_type == DocumentType.ONEDRIVE_FILE,
+ func.lower(
+ cast(
+ Document.document_metadata[
+ "onedrive_file_name"
+ ],
+ String,
+ )
+ )
+ == func.lower(file_name),
+ SearchSourceConnector.user_id == user_id,
+ )
+ )
+ .order_by(Document.updated_at.desc().nullslast())
+ .limit(1)
)
- )
- )
- connector = conn_result.scalars().first()
- if not connector:
- return {
- "status": "error",
- "message": "OneDrive connector not found or access denied.",
- }
+ document = doc_result.scalars().first()
- cfg = connector.config or {}
- if cfg.get("auth_expired"):
- return {
- "status": "auth_error",
- "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
- "connector_type": "onedrive",
- }
+ if not document:
+ return {
+ "status": "not_found",
+ "message": (
+ f"File '{file_name}' not found in your indexed OneDrive files. "
+ "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
+ "or (3) the file name is different."
+ ),
+ }
- context = {
- "file": {
- "file_id": file_id,
- "name": file_name,
- "document_id": document_id,
- "web_url": meta.get("web_url"),
- },
- "account": {
- "id": connector.id,
- "name": connector.name,
- "user_email": cfg.get("user_email"),
- },
- }
+ if not document.connector_id:
+ return {
+ "status": "error",
+ "message": "Document has no associated connector.",
+ }
- result = request_approval(
- action_type="onedrive_file_trash",
- tool_name="delete_onedrive_file",
- params={
- "file_id": file_id,
- "connector_id": connector.id,
- "delete_from_kb": delete_from_kb,
- },
- context=context,
- )
+ meta = document.document_metadata or {}
+ file_id = meta.get("onedrive_file_id")
+ document_id = document.id
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Do not retry or suggest alternatives.",
- }
+ if not file_id:
+ return {
+ "status": "error",
+ "message": "File ID is missing. Please re-index the file.",
+ }
- final_file_id = result.params.get("file_id", file_id)
- final_connector_id = result.params.get("connector_id", connector.id)
- final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
- if final_connector_id != connector.id:
- result = await db_session.execute(
+ conn_result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
- SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
@@ -207,65 +166,130 @@ def create_delete_onedrive_file_tool(
)
)
)
- validated_connector = result.scalars().first()
- if not validated_connector:
+ connector = conn_result.scalars().first()
+ if not connector:
return {
"status": "error",
- "message": "Selected OneDrive connector is invalid or has been disconnected.",
+ "message": "OneDrive connector not found or access denied.",
}
- actual_connector_id = validated_connector.id
- else:
- actual_connector_id = connector.id
- logger.info(
- f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
- )
+ cfg = connector.config or {}
+ if cfg.get("auth_expired"):
+ return {
+ "status": "auth_error",
+ "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
+ "connector_type": "onedrive",
+ }
- client = OneDriveClient(
- session=db_session, connector_id=actual_connector_id
- )
- await client.trash_file(final_file_id)
+ context = {
+ "file": {
+ "file_id": file_id,
+ "name": file_name,
+ "document_id": document_id,
+ "web_url": meta.get("web_url"),
+ },
+ "account": {
+ "id": connector.id,
+ "name": connector.name,
+ "user_email": cfg.get("user_email"),
+ },
+ }
- logger.info(
- f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
- )
-
- trash_result: dict[str, Any] = {
- "status": "success",
- "file_id": final_file_id,
- "message": f"Successfully moved '{file_name}' to the recycle bin.",
- }
-
- deleted_from_kb = False
- if final_delete_from_kb and document_id:
- try:
- doc_result = await db_session.execute(
- select(Document).filter(Document.id == document_id)
- )
- doc = doc_result.scalars().first()
- if doc:
- await db_session.delete(doc)
- await db_session.commit()
- deleted_from_kb = True
- logger.info(
- f"Deleted document {document_id} from knowledge base"
- )
- else:
- logger.warning(f"Document {document_id} not found in KB")
- except Exception as e:
- logger.error(f"Failed to delete document from KB: {e}")
- await db_session.rollback()
- trash_result["warning"] = (
- f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
- )
-
- trash_result["deleted_from_kb"] = deleted_from_kb
- if deleted_from_kb:
- trash_result["message"] = (
- f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ result = request_approval(
+ action_type="onedrive_file_trash",
+ tool_name="delete_onedrive_file",
+ params={
+ "file_id": file_id,
+ "connector_id": connector.id,
+ "delete_from_kb": delete_from_kb,
+ },
+ context=context,
)
- return trash_result
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Do not retry or suggest alternatives.",
+ }
+
+ final_file_id = result.params.get("file_id", file_id)
+ final_connector_id = result.params.get("connector_id", connector.id)
+ final_delete_from_kb = result.params.get(
+ "delete_from_kb", delete_from_kb
+ )
+
+ if final_connector_id != connector.id:
+ result = await db_session.execute(
+ select(SearchSourceConnector).filter(
+ and_(
+ SearchSourceConnector.id == final_connector_id,
+ SearchSourceConnector.search_space_id
+ == search_space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
+ )
+ )
+ )
+ validated_connector = result.scalars().first()
+ if not validated_connector:
+ return {
+ "status": "error",
+ "message": "Selected OneDrive connector is invalid or has been disconnected.",
+ }
+ actual_connector_id = validated_connector.id
+ else:
+ actual_connector_id = connector.id
+
+ logger.info(
+ f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
+ )
+
+ client = OneDriveClient(
+ session=db_session, connector_id=actual_connector_id
+ )
+ await client.trash_file(final_file_id)
+
+ logger.info(
+ f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
+ )
+
+ trash_result: dict[str, Any] = {
+ "status": "success",
+ "file_id": final_file_id,
+ "message": f"Successfully moved '{file_name}' to the recycle bin.",
+ }
+
+ deleted_from_kb = False
+ if final_delete_from_kb and document_id:
+ try:
+ doc_result = await db_session.execute(
+ select(Document).filter(Document.id == document_id)
+ )
+ doc = doc_result.scalars().first()
+ if doc:
+ await db_session.delete(doc)
+ await db_session.commit()
+ deleted_from_kb = True
+ logger.info(
+ f"Deleted document {document_id} from knowledge base"
+ )
+ else:
+ logger.warning(f"Document {document_id} not found in KB")
+ except Exception as e:
+ logger.error(f"Failed to delete document from KB: {e}")
+ await db_session.rollback()
+ trash_result["warning"] = (
+ f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
+ )
+
+ trash_result["deleted_from_kb"] = deleted_from_kb
+ if deleted_from_kb:
+ trash_result["message"] = (
+ f"{trash_result.get('message', '')} (also removed from knowledge base)"
+ )
+
+ return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py
index e8bab36fd..b842d7a20 100644
--- a/surfsense_backend/app/agents/new_chat/tools/registry.py
+++ b/surfsense_backend/app/agents/new_chat/tools/registry.py
@@ -824,13 +824,22 @@ async def build_tools_async(
"""Async version of build_tools that also loads MCP tools from database.
Design Note:
- This function exists because MCP tools require database queries to load user configs,
- while built-in tools are created synchronously from static code.
+ This function exists because MCP tools require database queries to load
+ user configs, while built-in tools are created synchronously from static
+ code.
- Alternative: We could make build_tools() itself async and always query the database,
- but that would force async everywhere even when only using built-in tools. The current
- design keeps the simple case (static tools only) synchronous while supporting dynamic
- database-loaded tools through this async wrapper.
+ Alternative: We could make build_tools() itself async and always query
+ the database, but that would force async everywhere even when only using
+ built-in tools. The current design keeps the simple case (static tools
+ only) synchronous while supporting dynamic database-loaded tools through
+ this async wrapper.
+
+ Phase 1.3: built-in tool construction (CPU; runs in a thread pool to
+ avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on
+ the event loop) are kicked off concurrently. Cold-path savings are
+ bounded by the slower of the two — typically MCP at ~200ms-1.7s —
+ so the parallelization recovers the ~50-200ms previously spent
+ serially on built-in construction.
Args:
dependencies: Dict containing all possible dependencies
@@ -843,33 +852,70 @@ async def build_tools_async(
List of configured tool instances ready for the agent, including MCP tools.
"""
+ import asyncio
import time
_perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG)
+ can_load_mcp = (
+ include_mcp_tools
+ and "db_session" in dependencies
+ and "search_space_id" in dependencies
+ )
+
+ # Built-in tool construction is synchronous + CPU-only. Off-loop it so
+    # MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is a pure
+    # function of its inputs, so it is safe to run in a worker thread.
_t0 = time.perf_counter()
- tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
+ builtin_task = asyncio.create_task(
+ asyncio.to_thread(
+ build_tools, dependencies, enabled_tools, disabled_tools, additional_tools
+ )
+ )
+
+ mcp_task: asyncio.Task | None = None
+ if can_load_mcp:
+ mcp_task = asyncio.create_task(
+ load_mcp_tools(
+ dependencies["db_session"],
+ dependencies["search_space_id"],
+ )
+ )
+
+ # Surface failures from each task independently so a flaky MCP
+ # endpoint never poisons built-in tool registration. ``return_exceptions``
+ # gives us per-task exceptions instead of dropping the second result
+ # when the first raises.
+ if mcp_task is not None:
+ builtin_result, mcp_result = await asyncio.gather(
+ builtin_task, mcp_task, return_exceptions=True
+ )
+ else:
+ builtin_result = await builtin_task
+ mcp_result = None
+
+ if isinstance(builtin_result, BaseException):
+ raise builtin_result # built-in registration failure is non-recoverable
+ tools: list[BaseTool] = builtin_result
_perf_log.info(
- "[build_tools_async] Built-in tools in %.3fs (%d tools)",
+ "[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)",
time.perf_counter() - _t0,
len(tools),
)
- # Load MCP tools if requested and dependencies are available
- if (
- include_mcp_tools
- and "db_session" in dependencies
- and "search_space_id" in dependencies
- ):
- try:
- _t0 = time.perf_counter()
- mcp_tools = await load_mcp_tools(
- dependencies["db_session"],
- dependencies["search_space_id"],
+ if mcp_task is not None:
+ if isinstance(mcp_result, BaseException):
+ # ``return_exceptions=True`` captures the exception out-of-band,
+ # so ``sys.exc_info()`` is empty here. Pass the captured
+ # exception via ``exc_info=`` to get a real traceback.
+ logging.error(
+ "Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result
)
+ else:
+ mcp_tools = mcp_result or []
_perf_log.info(
- "[build_tools_async] MCP tools loaded in %.3fs (%d tools)",
+ "[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)",
time.perf_counter() - _t0,
len(mcp_tools),
)
@@ -879,8 +925,6 @@ async def build_tools_async(
len(mcp_tools),
[t.name for t in mcp_tools],
)
- except Exception as e:
- logging.exception("Failed to load MCP tools: %s", e)
logging.info(
"Total tools for agent: %d — %s",
diff --git a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py b/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py
index b8b1527c7..2965f2f02 100644
--- a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py
+++ b/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py
@@ -15,7 +15,7 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
-from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument
+from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker
from app.utils.document_converters import embed_text
@@ -124,12 +124,19 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
"""
Factory function to create the search_surfsense_docs tool.
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
Args:
- db_session: Database session for executing queries
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
Returns:
A configured tool function for searching Surfsense documentation
"""
+ del db_session # per-call session — see docstring
@tool
async def search_surfsense_docs(query: str, top_k: int = 10) -> str:
@@ -155,10 +162,11 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
Returns:
Relevant documentation content formatted with chunk IDs for citations
"""
- return await search_surfsense_docs_async(
- query=query,
- db_session=db_session,
- top_k=top_k,
- )
+ async with async_session_maker() as db_session:
+ return await search_surfsense_docs_async(
+ query=query,
+ db_session=db_session,
+ top_k=top_k,
+ )
return search_surfsense_docs
diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py
index d7b000853..0fc52b5c7 100644
--- a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py
+++ b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_list_teams_channels_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the list_teams_channels tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured list_teams_channels tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def list_teams_channels() -> dict[str, Any]:
"""List all Microsoft Teams and their channels the user has access to.
@@ -23,63 +42,66 @@ def create_list_teams_channels_tool(
Dictionary with status and a list of teams, each containing
team_id, team_name, and a list of channels (id, name).
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
try:
- connector = await get_teams_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Teams connector found."}
-
- token = await get_access_token(db_session, connector)
- headers = {"Authorization": f"Bearer {token}"}
-
- async with httpx.AsyncClient(timeout=20.0) as client:
- teams_resp = await client.get(
- f"{GRAPH_API}/me/joinedTeams", headers=headers
+ async with async_session_maker() as db_session:
+ connector = await get_teams_connector(
+ db_session, search_space_id, user_id
)
+ if not connector:
+ return {"status": "error", "message": "No Teams connector found."}
- if teams_resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Teams token expired. Please re-authenticate.",
- "connector_type": "teams",
- }
- if teams_resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Graph API error: {teams_resp.status_code}",
- }
+ token = await get_access_token(db_session, connector)
+ headers = {"Authorization": f"Bearer {token}"}
- teams_data = teams_resp.json().get("value", [])
- result_teams = []
-
- async with httpx.AsyncClient(timeout=20.0) as client:
- for team in teams_data:
- team_id = team["id"]
- ch_resp = await client.get(
- f"{GRAPH_API}/teams/{team_id}/channels",
- headers=headers,
- )
- channels = []
- if ch_resp.status_code == 200:
- channels = [
- {"id": ch["id"], "name": ch.get("displayName", "")}
- for ch in ch_resp.json().get("value", [])
- ]
- result_teams.append(
- {
- "team_id": team_id,
- "team_name": team.get("displayName", ""),
- "channels": channels,
- }
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ teams_resp = await client.get(
+ f"{GRAPH_API}/me/joinedTeams", headers=headers
)
- return {
- "status": "success",
- "teams": result_teams,
- "total_teams": len(result_teams),
- }
+ if teams_resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Teams token expired. Please re-authenticate.",
+ "connector_type": "teams",
+ }
+ if teams_resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Graph API error: {teams_resp.status_code}",
+ }
+
+ teams_data = teams_resp.json().get("value", [])
+ result_teams = []
+
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ for team in teams_data:
+ team_id = team["id"]
+ ch_resp = await client.get(
+ f"{GRAPH_API}/teams/{team_id}/channels",
+ headers=headers,
+ )
+ channels = []
+ if ch_resp.status_code == 200:
+ channels = [
+ {"id": ch["id"], "name": ch.get("displayName", "")}
+ for ch in ch_resp.json().get("value", [])
+ ]
+ result_teams.append(
+ {
+ "team_id": team_id,
+ "team_name": team.get("displayName", ""),
+ "channels": channels,
+ }
+ )
+
+ return {
+ "status": "success",
+ "teams": result_teams,
+ "total_teams": len(result_teams),
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py
index d24a7e4d3..0ebda021e 100644
--- a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py
+++ b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py
@@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
+from app.db import async_session_maker
+
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_read_teams_messages_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the read_teams_messages tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured read_teams_messages tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def read_teams_messages(
team_id: str,
@@ -32,65 +51,68 @@ def create_read_teams_messages_tool(
Dictionary with status and a list of messages including
id, sender, content, timestamp.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
limit = min(limit, 50)
try:
- connector = await get_teams_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Teams connector found."}
-
- token = await get_access_token(db_session, connector)
-
- async with httpx.AsyncClient(timeout=20.0) as client:
- resp = await client.get(
- f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
- headers={"Authorization": f"Bearer {token}"},
- params={"$top": limit},
+ async with async_session_maker() as db_session:
+ connector = await get_teams_connector(
+ db_session, search_space_id, user_id
)
+ if not connector:
+ return {"status": "error", "message": "No Teams connector found."}
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Teams token expired. Please re-authenticate.",
- "connector_type": "teams",
- }
- if resp.status_code == 403:
- return {
- "status": "error",
- "message": "Insufficient permissions to read this channel.",
- }
- if resp.status_code != 200:
- return {
- "status": "error",
- "message": f"Graph API error: {resp.status_code}",
- }
+ token = await get_access_token(db_session, connector)
- raw_msgs = resp.json().get("value", [])
- messages = []
- for m in raw_msgs:
- sender = m.get("from", {})
- user_info = sender.get("user", {}) if sender else {}
- body = m.get("body", {})
- messages.append(
- {
- "id": m.get("id"),
- "sender": user_info.get("displayName", "Unknown"),
- "content": body.get("content", ""),
- "content_type": body.get("contentType", "text"),
- "timestamp": m.get("createdDateTime", ""),
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ resp = await client.get(
+ f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
+ headers={"Authorization": f"Bearer {token}"},
+ params={"$top": limit},
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Teams token expired. Please re-authenticate.",
+ "connector_type": "teams",
+ }
+ if resp.status_code == 403:
+ return {
+ "status": "error",
+ "message": "Insufficient permissions to read this channel.",
+ }
+ if resp.status_code != 200:
+ return {
+ "status": "error",
+ "message": f"Graph API error: {resp.status_code}",
}
- )
- return {
- "status": "success",
- "team_id": team_id,
- "channel_id": channel_id,
- "messages": messages,
- "total": len(messages),
- }
+ raw_msgs = resp.json().get("value", [])
+ messages = []
+ for m in raw_msgs:
+ sender = m.get("from", {})
+ user_info = sender.get("user", {}) if sender else {}
+ body = m.get("body", {})
+ messages.append(
+ {
+ "id": m.get("id"),
+ "sender": user_info.get("displayName", "Unknown"),
+ "content": body.get("content", ""),
+ "content_type": body.get("contentType", "text"),
+ "timestamp": m.get("createdDateTime", ""),
+ }
+ )
+
+ return {
+ "status": "success",
+ "team_id": team_id,
+ "channel_id": channel_id,
+ "messages": messages,
+ "total": len(messages),
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py
index fd8d00870..6f40d27e1 100644
--- a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py
+++ b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py
@@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector
@@ -17,6 +18,23 @@ def create_send_teams_message_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
+ """
+ Factory function to create the send_teams_message tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+
+ Args:
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+
+ Returns:
+ Configured send_teams_message tool
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def send_teams_message(
team_id: str,
@@ -39,70 +57,73 @@ def create_send_teams_message_tool(
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
- if db_session is None or search_space_id is None or user_id is None:
+ if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
try:
- connector = await get_teams_connector(db_session, search_space_id, user_id)
- if not connector:
- return {"status": "error", "message": "No Teams connector found."}
+ async with async_session_maker() as db_session:
+ connector = await get_teams_connector(
+ db_session, search_space_id, user_id
+ )
+ if not connector:
+ return {"status": "error", "message": "No Teams connector found."}
- result = request_approval(
- action_type="teams_send_message",
- tool_name="send_teams_message",
- params={
- "team_id": team_id,
- "channel_id": channel_id,
- "content": content,
- },
- context={"connector_id": connector.id},
- )
-
- if result.rejected:
- return {
- "status": "rejected",
- "message": "User declined. Message was not sent.",
- }
-
- final_content = result.params.get("content", content)
- final_team = result.params.get("team_id", team_id)
- final_channel = result.params.get("channel_id", channel_id)
-
- token = await get_access_token(db_session, connector)
-
- async with httpx.AsyncClient(timeout=20.0) as client:
- resp = await client.post(
- f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
- headers={
- "Authorization": f"Bearer {token}",
- "Content-Type": "application/json",
+ result = request_approval(
+ action_type="teams_send_message",
+ tool_name="send_teams_message",
+ params={
+ "team_id": team_id,
+ "channel_id": channel_id,
+ "content": content,
},
- json={"body": {"content": final_content}},
+ context={"connector_id": connector.id},
)
- if resp.status_code == 401:
- return {
- "status": "auth_error",
- "message": "Teams token expired. Please re-authenticate.",
- "connector_type": "teams",
- }
- if resp.status_code == 403:
- return {
- "status": "insufficient_permissions",
- "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
- }
- if resp.status_code not in (200, 201):
- return {
- "status": "error",
- "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}",
- }
+ if result.rejected:
+ return {
+ "status": "rejected",
+ "message": "User declined. Message was not sent.",
+ }
- msg_data = resp.json()
- return {
- "status": "success",
- "message_id": msg_data.get("id"),
- "message": "Message sent to Teams channel.",
- }
+ final_content = result.params.get("content", content)
+ final_team = result.params.get("team_id", team_id)
+ final_channel = result.params.get("channel_id", channel_id)
+
+ token = await get_access_token(db_session, connector)
+
+ async with httpx.AsyncClient(timeout=20.0) as client:
+ resp = await client.post(
+ f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
+ headers={
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json",
+ },
+ json={"body": {"content": final_content}},
+ )
+
+ if resp.status_code == 401:
+ return {
+ "status": "auth_error",
+ "message": "Teams token expired. Please re-authenticate.",
+ "connector_type": "teams",
+ }
+ if resp.status_code == 403:
+ return {
+ "status": "insufficient_permissions",
+ "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
+ }
+ if resp.status_code not in (200, 201):
+ return {
+ "status": "error",
+ "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}",
+ }
+
+ msg_data = resp.json()
+ return {
+ "status": "success",
+ "message_id": msg_data.get("id"),
+ "message": "Message sent to Teams channel.",
+ }
except Exception as e:
from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/update_memory.py b/surfsense_backend/app/agents/new_chat/tools/update_memory.py
index 4128ac0dc..fbc9edbba 100644
--- a/surfsense_backend/app/agents/new_chat/tools/update_memory.py
+++ b/surfsense_backend/app/agents/new_chat/tools/update_memory.py
@@ -26,7 +26,7 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
-from app.db import SearchSpace, User
+from app.db import SearchSpace, User, async_session_maker
logger = logging.getLogger(__name__)
@@ -295,6 +295,25 @@ def create_update_memory_tool(
db_session: AsyncSession,
llm: Any | None = None,
):
+ """Factory function to create the user-memory update tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+ The session's bound ``commit``/``rollback`` methods are captured at
+ call time, after ``async with`` has bound ``db_session`` locally.
+
+ Args:
+ user_id: ID of the user whose memory document is being updated.
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ llm: Optional LLM for the forced-rewrite path.
+
+ Returns:
+ Configured update_memory tool for the user-memory scope.
+ """
+ del db_session # per-call session — see docstring
uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool
@@ -311,26 +330,26 @@ def create_update_memory_tool(
updated_memory: The FULL updated markdown document (not a diff).
"""
try:
- result = await db_session.execute(select(User).where(User.id == uid))
- user = result.scalars().first()
- if not user:
- return {"status": "error", "message": "User not found."}
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(select(User).where(User.id == uid))
+ user = result.scalars().first()
+ if not user:
+ return {"status": "error", "message": "User not found."}
- old_memory = user.memory_md
+ old_memory = user.memory_md
- return await _save_memory(
- updated_memory=updated_memory,
- old_memory=old_memory,
- llm=llm,
- apply_fn=lambda content: setattr(user, "memory_md", content),
- commit_fn=db_session.commit,
- rollback_fn=db_session.rollback,
- label="memory",
- scope="user",
- )
+ return await _save_memory(
+ updated_memory=updated_memory,
+ old_memory=old_memory,
+ llm=llm,
+ apply_fn=lambda content: setattr(user, "memory_md", content),
+ commit_fn=db_session.commit,
+ rollback_fn=db_session.rollback,
+ label="memory",
+ scope="user",
+ )
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
- await db_session.rollback()
return {
"status": "error",
"message": f"Failed to update memory: {e}",
@@ -344,6 +363,27 @@ def create_update_team_memory_tool(
db_session: AsyncSession,
llm: Any | None = None,
):
+ """Factory function to create the team-memory update tool.
+
+ The tool acquires its own short-lived ``AsyncSession`` per call via
+ :data:`async_session_maker` so the closure is safe to share across
+ HTTP requests by the compiled-agent cache. Capturing a per-request
+ session here would surface stale/closed sessions on cache hits.
+ The session's bound ``commit``/``rollback`` methods are captured at
+ call time, after ``async with`` has bound ``db_session`` locally.
+
+ Args:
+ search_space_id: ID of the search space whose team memory is being
+ updated.
+ db_session: Reserved for registry compatibility. Per-call sessions
+ are opened via :data:`async_session_maker` inside the tool body.
+ llm: Optional LLM for the forced-rewrite path.
+
+ Returns:
+ Configured update_memory tool for the team-memory scope.
+ """
+ del db_session # per-call session — see docstring
+
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space.
@@ -359,28 +399,30 @@ def create_update_team_memory_tool(
updated_memory: The FULL updated markdown document (not a diff).
"""
try:
- result = await db_session.execute(
- select(SearchSpace).where(SearchSpace.id == search_space_id)
- )
- space = result.scalars().first()
- if not space:
- return {"status": "error", "message": "Search space not found."}
+ async with async_session_maker() as db_session:
+ result = await db_session.execute(
+ select(SearchSpace).where(SearchSpace.id == search_space_id)
+ )
+ space = result.scalars().first()
+ if not space:
+ return {"status": "error", "message": "Search space not found."}
- old_memory = space.shared_memory_md
+ old_memory = space.shared_memory_md
- return await _save_memory(
- updated_memory=updated_memory,
- old_memory=old_memory,
- llm=llm,
- apply_fn=lambda content: setattr(space, "shared_memory_md", content),
- commit_fn=db_session.commit,
- rollback_fn=db_session.rollback,
- label="team memory",
- scope="team",
- )
+ return await _save_memory(
+ updated_memory=updated_memory,
+ old_memory=old_memory,
+ llm=llm,
+ apply_fn=lambda content: setattr(
+ space, "shared_memory_md", content
+ ),
+ commit_fn=db_session.commit,
+ rollback_fn=db_session.rollback,
+ label="team memory",
+ scope="team",
+ )
except Exception as e:
logger.exception("Failed to update team memory: %s", e)
- await db_session.rollback()
return {
"status": "error",
"message": f"Failed to update team memory: {e}",
diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py
index 14d7f4d23..2c9b4f390 100644
--- a/surfsense_backend/app/app.py
+++ b/surfsense_backend/app/app.py
@@ -421,6 +421,135 @@ def _stop_openrouter_background_refresh() -> None:
OpenRouterIntegrationService.get_instance().stop_background_refresh()
+async def _warm_agent_jit_caches() -> None:
+ """Pay the LangChain / LangGraph / Deepagents JIT cost at startup.
+
+ Why
+ ----
+ A cold ``create_agent`` + ``StateGraph.compile()`` + Pydantic schema
+ generation chain takes 1.5-2 seconds of pure CPU on first invocation
+ inside any Python process: the graph compiler builds reducers,
+ Pydantic v2 generates and JITs validator schemas, deepagents
+ eagerly compiles its general-purpose subagent, etc. Subsequent
+ compiles in the same process pay only ~50% of that cost (the lazy
+ JIT bits are cached in module-level dicts).
+
+ Doing one throwaway compile during ``lifespan`` startup pre-pays
+ that cost so the *first real request* doesn't. We do NOT prime
+ :mod:`agent_cache` because the cache key requires real
+ ``thread_id`` / ``user_id`` / ``search_space_id`` / etc. — the
+ throwaway agent is genuinely thrown away and immediately collected.
+
+ Safety
+ ------
+ * No DB access. We construct a stub LLM (no real keys), pass an
+ empty tools list, and pass ``checkpointer=None`` so we never
+ touch Postgres.
+ * Bounded by ``asyncio.wait_for`` so a hang here can never block
+ worker startup. On any failure, we log + swallow — the worst
+ case is the first real request pays the full cold cost (i.e.
+ pre-warmup behaviour).
+ """
+ import time as _time
+
+ logger = logging.getLogger(__name__)
+ t0 = _time.perf_counter()
+ try:
+ from langchain.agents import create_agent
+ from langchain.agents.middleware import (
+ ModelCallLimitMiddleware,
+ TodoListMiddleware,
+ ToolCallLimitMiddleware,
+ )
+ from langchain_core.language_models.fake_chat_models import (
+ FakeListChatModel,
+ )
+ from langchain_core.tools import tool
+
+ from app.agents.new_chat.context import SurfSenseContextSchema
+
+ # Minimal LLM stub. ``FakeListChatModel`` satisfies
+ # ``BaseChatModel`` without any network or auth — perfect for
+ # exercising the compile path without side effects.
+ stub_llm = FakeListChatModel(responses=["warmup-response"])
+
+ # Two trivial tools with arg + return schemas — exercises the
+ # Pydantic v2 schema JIT path. Without at least one tool the
+ # graph compile skips the tool-loop bytecode generation that
+ # accounts for ~30-50% of cold compile cost.
+ @tool
+ def _warmup_tool_a(query: str, limit: int = 5) -> str:
+ """Warmup tool A — never actually invoked."""
+ return query[:limit]
+
+ @tool
+ def _warmup_tool_b(name: str, value: float | None = None) -> dict[str, object]:
+ """Warmup tool B — never actually invoked."""
+ return {"name": name, "value": value}
+
+ # A handful of common middleware so the compile pre-pays the
+ # ``AgentMiddleware`` resolver path. These instances never run
+ # because the throwaway agent is immediately collected.
+ # ``SubAgentMiddleware`` is the single heaviest line in cold
+ # ``create_surfsense_deep_agent`` (1.5-2s of CPU per call to
+ # compile its general-purpose subagent's full inner graph),
+ # so we include it here to make sure that compile path is JIT'd.
+ warmup_middleware: list = [
+ TodoListMiddleware(),
+ ModelCallLimitMiddleware(
+ thread_limit=120, run_limit=80, exit_behavior="end"
+ ),
+ ToolCallLimitMiddleware(
+ thread_limit=300, run_limit=80, exit_behavior="continue"
+ ),
+ ]
+ try:
+ from deepagents import SubAgentMiddleware
+ from deepagents.backends import StateBackend
+ from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
+
+ gp_warmup_spec = { # type: ignore[var-annotated]
+ **GENERAL_PURPOSE_SUBAGENT,
+ "model": stub_llm,
+ "tools": [_warmup_tool_a],
+ "middleware": [TodoListMiddleware()],
+ }
+ warmup_middleware.append(
+ SubAgentMiddleware(backend=StateBackend, subagents=[gp_warmup_spec])
+ )
+ except Exception:
+ # Deepagents missing/incompatible — middleware-only warmup
+ # still produces a useful (smaller) speedup.
+ logger.debug("[startup] SubAgentMiddleware warmup skipped", exc_info=True)
+
+ compiled = create_agent(
+ stub_llm,
+ tools=[_warmup_tool_a, _warmup_tool_b],
+ system_prompt="You are a warmup stub.",
+ middleware=warmup_middleware,
+ context_schema=SurfSenseContextSchema,
+ checkpointer=None,
+ )
+
+        # Touch the compiled graph's nodes so any remaining lazy
+        # schema work fires now instead of on the first real
+        # invocation.
+ _ = list(getattr(compiled, "nodes", {}).keys())
+
+ del compiled
+ logger.info(
+ "[startup] Agent JIT warmup completed in %.3fs",
+ _time.perf_counter() - t0,
+ )
+ except Exception:
+ logger.warning(
+ "[startup] Agent JIT warmup failed in %.3fs (non-fatal — first "
+ "real request will pay the full compile cost)",
+ _time.perf_counter() - t0,
+ exc_info=True,
+ )
+
+
@asynccontextmanager
async def lifespan(app: FastAPI):
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
@@ -445,6 +574,18 @@ async def lifespan(app: FastAPI):
"Docs will be indexed on the next restart."
)
+ # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
+ # worker readiness. ``shield`` so Uvicorn cancelling startup
+ # doesn't leave half-warmed Pydantic schemas in an inconsistent
+ # state.
+ try:
+ await asyncio.wait_for(asyncio.shield(_warm_agent_jit_caches()), timeout=20)
+    except Exception:  # pragma: no cover - defensive (includes TimeoutError)
+ logging.getLogger(__name__).warning(
+ "[startup] Agent JIT warmup hit timeout/error — skipping; "
+ "first real request will pay the full compile cost."
+ )
+
log_system_snapshot("startup_complete")
yield
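
A minimal sketch of the bounded-warmup pattern used above, with a hypothetical
stand-in for the warmup coroutine (the real call site is the lifespan hook):

    import asyncio

    async def _slow_warmup() -> None:
        # Hypothetical stand-in for _warm_agent_jit_caches(); may run long.
        await asyncio.sleep(60)

    async def startup() -> None:
        try:
            # wait_for bounds startup latency; shield keeps the inner task
            # alive after a timeout instead of cancelling it half-finished.
            await asyncio.wait_for(asyncio.shield(_slow_warmup()), timeout=20)
        except asyncio.TimeoutError:
            # Startup proceeds immediately; the warmup keeps running in the
            # background and its outcome is simply discarded.
            pass
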
diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py
index 7c55da2e5..45bcfd00f 100644
--- a/surfsense_backend/app/services/connector_service.py
+++ b/surfsense_backend/app/services/connector_service.py
@@ -1,6 +1,8 @@
import asyncio
+import os
import time
from datetime import datetime
+from threading import Lock
from typing import Any
import httpx
@@ -2769,12 +2771,22 @@ class ConnectorService:
"""
Get all available (enabled) connector types for a search space.
+ Phase 1.4: results are cached per ``search_space_id`` for
+ :data:`_DISCOVERY_TTL_SECONDS`. Cache key is independent of session
+ identity — the cached value is plain data, safe to share across
+ requests. Invalidate on connector add/update/delete via
+ :func:`invalidate_connector_discovery_cache`.
+
Args:
search_space_id: The search space ID
Returns:
List of SearchSourceConnectorType enums for enabled connectors
"""
+ cached = _get_cached_connectors(search_space_id)
+ if cached is not None:
+ return list(cached)
+
query = (
select(SearchSourceConnector.connector_type)
.filter(
@@ -2784,8 +2796,9 @@ class ConnectorService:
)
result = await self.session.execute(query)
- connector_types = result.scalars().all()
- return list(connector_types)
+ connector_types = list(result.scalars().all())
+ _set_cached_connectors(search_space_id, connector_types)
+ return connector_types
async def get_available_document_types(
self,
@@ -2794,12 +2807,22 @@ class ConnectorService:
"""
Get all document types that have at least one document in the search space.
+ Phase 1.4: cached per ``search_space_id`` for
+ :data:`_DISCOVERY_TTL_SECONDS`. Invalidate via
+ :func:`invalidate_connector_discovery_cache` when a connector
+ finishes indexing new documents (or document types are otherwise
+ added/removed).
+
Args:
search_space_id: The search space ID
Returns:
List of document type strings that have documents indexed
"""
+ cached = _get_cached_doc_types(search_space_id)
+ if cached is not None:
+ return list(cached)
+
from sqlalchemy import distinct
from app.db import Document
@@ -2809,5 +2832,164 @@ class ConnectorService:
)
result = await self.session.execute(query)
- doc_types = result.scalars().all()
- return [str(dt) for dt in doc_types]
+ doc_types = [str(dt) for dt in result.scalars().all()]
+ _set_cached_doc_types(search_space_id, doc_types)
+ return doc_types
+
+
+# ---------------------------------------------------------------------------
+# Connector / document-type discovery TTL cache (Phase 1.4)
+# ---------------------------------------------------------------------------
+#
+# Both ``get_available_connectors`` and ``get_available_document_types`` are
+# called on EVERY chat turn from ``create_surfsense_deep_agent``. Each query
+# hits Postgres and contributes to per-turn agent build latency. Their
+# results change infrequently — only when the user adds/edits/removes a
+# connector, or when an indexer commits a new document type. A short TTL
+# cache (default 30s, env-tunable) collapses N concurrent calls into one
+# DB roundtrip with bounded staleness.
+#
+# Invalidation: connector mutation routes (create / update / delete) call
+# ``invalidate_connector_discovery_cache(search_space_id)`` to clear the
+# entry for the affected space. Multi-replica deployments still pay one
+# DB roundtrip per replica per TTL window, which is fine — staleness is
+# bounded and the alternative (cross-replica fanout) is not worth the
+# coupling here.
+
+_DISCOVERY_TTL_SECONDS: float = float(
+ os.getenv("SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS", "30")
+)
+
+# Per-search-space caches. Keyed by ``search_space_id``; value is
+# ``(expires_at_monotonic, payload)``. Plain dicts protected by a lock —
+# read-mostly workload, sub-microsecond contention.
+_connectors_cache: dict[int, tuple[float, list[SearchSourceConnectorType]]] = {}
+_doc_types_cache: dict[int, tuple[float, list[str]]] = {}
+_cache_lock = Lock()
+
+
+def _get_cached_connectors(
+ search_space_id: int,
+) -> list[SearchSourceConnectorType] | None:
+ if _DISCOVERY_TTL_SECONDS <= 0:
+ return None
+ with _cache_lock:
+ entry = _connectors_cache.get(search_space_id)
+ if entry is None:
+ return None
+ expires_at, payload = entry
+ if time.monotonic() >= expires_at:
+ _connectors_cache.pop(search_space_id, None)
+ return None
+ return payload
+
+
+def _set_cached_connectors(
+ search_space_id: int, payload: list[SearchSourceConnectorType]
+) -> None:
+ if _DISCOVERY_TTL_SECONDS <= 0:
+ return
+ expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
+ with _cache_lock:
+ _connectors_cache[search_space_id] = (expires_at, list(payload))
+
+
+def _get_cached_doc_types(search_space_id: int) -> list[str] | None:
+ if _DISCOVERY_TTL_SECONDS <= 0:
+ return None
+ with _cache_lock:
+ entry = _doc_types_cache.get(search_space_id)
+ if entry is None:
+ return None
+ expires_at, payload = entry
+ if time.monotonic() >= expires_at:
+ _doc_types_cache.pop(search_space_id, None)
+ return None
+ return payload
+
+
+def _set_cached_doc_types(search_space_id: int, payload: list[str]) -> None:
+ if _DISCOVERY_TTL_SECONDS <= 0:
+ return
+ expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
+ with _cache_lock:
+ _doc_types_cache[search_space_id] = (expires_at, list(payload))
+
+
+def invalidate_connector_discovery_cache(search_space_id: int | None = None) -> None:
+ """Drop cached discovery results for ``search_space_id`` (or all spaces).
+
+ Connector CRUD routes / indexer pipelines call this when they mutate
+ the rows backing :func:`ConnectorService.get_available_connectors` /
+ :func:`get_available_document_types`. ``None`` clears every space —
+ useful in tests and on bulk imports.
+ """
+ with _cache_lock:
+ if search_space_id is None:
+ _connectors_cache.clear()
+ _doc_types_cache.clear()
+ else:
+ _connectors_cache.pop(search_space_id, None)
+ _doc_types_cache.pop(search_space_id, None)
+
+
+def _invalidate_connectors_only(search_space_id: int | None = None) -> None:
+ with _cache_lock:
+ if search_space_id is None:
+ _connectors_cache.clear()
+ else:
+ _connectors_cache.pop(search_space_id, None)
+
+
+def _invalidate_doc_types_only(search_space_id: int | None = None) -> None:
+ with _cache_lock:
+ if search_space_id is None:
+ _doc_types_cache.clear()
+ else:
+ _doc_types_cache.pop(search_space_id, None)
+
+
+def _register_invalidation_listeners() -> None:
+ """Wire SQLAlchemy ORM events so cache stays consistent automatically.
+
+ Listening on ``after_insert`` / ``after_update`` / ``after_delete``
+ means every successful INSERT/UPDATE/DELETE that goes through the ORM
+ invalidates the affected search space's cached discovery payload —
+ no need to sprinkle ``invalidate_*`` calls across 30+ connector
+ routes. Bulk operations that bypass the ORM (e.g.
+ ``session.execute(insert(...))`` without a mapped object) still need
+ explicit invalidation; document indexers already commit through the
+ ORM so document-type discovery is covered.
+ """
+ from sqlalchemy import event
+
+ # Imported here (not at module top) to avoid a circular import:
+ # app.services.connector_service is itself imported from app.db's
+ # ecosystem indirectly via several CRUD modules.
+ from app.db import Document, SearchSourceConnector
+
+ def _connector_changed(_mapper, _connection, target) -> None:
+ sid = getattr(target, "search_space_id", None)
+ if sid is not None:
+ _invalidate_connectors_only(int(sid))
+
+ def _document_changed(_mapper, _connection, target) -> None:
+ sid = getattr(target, "search_space_id", None)
+ if sid is not None:
+ _invalidate_doc_types_only(int(sid))
+
+ for evt in ("after_insert", "after_update", "after_delete"):
+ event.listen(SearchSourceConnector, evt, _connector_changed)
+ event.listen(Document, evt, _document_changed)
+
+
+try:
+ _register_invalidation_listeners()
+except Exception: # pragma: no cover - defensive; never block module import
+ import logging as _logging
+
+ _logging.getLogger(__name__).exception(
+ "Failed to register connector discovery cache invalidation listeners; "
+ "stale cache risk: explicit invalidate_connector_discovery_cache calls "
+ "may be required."
+ )
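
The ORM-bypassing escape hatch mentioned above, sketched with a hypothetical
bulk-import helper; only invalidate_connector_discovery_cache and
SearchSourceConnector are real names here:

    from sqlalchemy import insert
    from sqlalchemy.ext.asyncio import AsyncSession

    from app.db import SearchSourceConnector
    from app.services.connector_service import invalidate_connector_discovery_cache

    async def bulk_import_connectors(
        session: AsyncSession, search_space_id: int, rows: list[dict]
    ) -> None:
        # Core/bulk insert: no ORM unit-of-work events fire, so the
        # after_insert listeners registered above never see these rows.
        await session.execute(insert(SearchSourceConnector), rows)
        await session.commit()
        # Hence the explicit cache drop for the affected search space.
        invalidate_connector_discovery_cache(search_space_id)
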
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index 268a4401e..f7ddd8909 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -31,6 +31,7 @@ from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
+from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.feature_flags import get_flags
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
@@ -559,6 +560,29 @@ async def _preflight_llm(llm: Any) -> None:
)
+async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
+ """Wait for a discarded speculative agent build to release shared state.
+
+ Used by the parallel preflight + agent-build path. The speculative build
+ closes over the request-scoped ``AsyncSession`` (for the brief connector
+ discovery / tool-factory window before its CPU work moves into a worker
+ thread). If preflight reports a 429 we want to fall back to the original
+ repin → reload → rebuild path, but we MUST NOT touch ``session`` again
+ until any in-flight session work owned by the speculative build has
+ fully settled — :class:`sqlalchemy.ext.asyncio.AsyncSession` is not
+ concurrency-safe and the same hazard cost us a hard ``InvalidRequestError``
+ earlier in this PR (see ``connector_service`` parallel-gather revert).
+
+ We simply ``await`` the task and swallow any exception: in this path the
+ build's outcome is irrelevant — success populates the agent cache (a free
+ side effect), failure is discarded. The wasted CPU is acceptable since
+ 429 fallbacks are rare and the original sequential code also paid the
+ full build cost on the same path.
+ """
+ with contextlib.suppress(BaseException):
+ await task
+
+
def _classify_stream_exception(
exc: Exception,
*,
@@ -696,6 +720,7 @@ async def _stream_agent_events(
fallback_commit_created_by_id: str | None = None,
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
fallback_commit_thread_id: int | None = None,
+ runtime_context: Any = None,
) -> AsyncGenerator[str, None]:
"""Shared async generator that streams and formats astream_events from the agent.
@@ -801,7 +826,18 @@ async def _stream_agent_events(
return event
return None
- async for event in agent.astream_events(input_data, config=config, version="v2"):
+ # Per-invocation runtime context (Phase 1.5). When supplied,
+ # ``KnowledgePriorityMiddleware`` reads ``mentioned_document_ids``
+ # from ``runtime.context`` instead of its constructor closure — the
+ # prerequisite that lets the compiled-agent cache (Phase 1) reuse a
+    # single graph across turns. The ``context`` kwarg is added only when
+    # callers supply ``runtime_context``; with ``None`` the call stays
+    # identical to the legacy code path bit-for-bit.
+ astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"}
+ if runtime_context is not None:
+ astream_kwargs["context"] = runtime_context
+
+ async for event in agent.astream_events(input_data, **astream_kwargs):
event_type = event.get("event", "")
if event_type == "on_chat_model_stream":
@@ -2560,23 +2596,102 @@ async def stream_new_chat(
# Detecting a 429 here lets us repin BEFORE the planner/classifier/
# title-generation LLM calls fan out and each independently hit the
# same upstream rate limit.
- if (
+ #
+ # PERF: preflight is a network round-trip to the LLM provider (~1-5s)
+ # and is independent of the agent build (CPU-bound, ~5-7s). They used
+ # to run sequentially → ``preflight + build`` on cold cache = 11.5s.
+ # We now kick off preflight as a background task FIRST, then run the
+ # synchronous setup work and the agent build in parallel. In the
+ # success path (the common case) total wall time drops to roughly
+ # ``max(preflight, build)`` — the preflight finishes during the
+ # agent compile and we just consume its result. In the rare 429
+ # path the speculative build is awaited to completion (so its
+ # session usage is fully released) via
+ # :func:`_settle_speculative_agent_build`, then discarded, and
+ # we fall back to the original repin-and-rebuild flow.
+ preflight_needed = (
requested_llm_config_id == 0
and llm_config_id < 0
and not is_recently_healthy(llm_config_id)
- ):
+ )
+ preflight_task: asyncio.Task[None] | None = None
+ _t_preflight = 0.0
+ if preflight_needed:
_t_preflight = time.perf_counter()
+ preflight_task = asyncio.create_task(
+ _preflight_llm(llm),
+ name=f"auto_pin_preflight:{llm_config_id}",
+ )
+
+ # Create connector service
+ _t0 = time.perf_counter()
+ connector_service = ConnectorService(session, search_space_id=search_space_id)
+
+ firecrawl_api_key = None
+ webcrawler_connector = await connector_service.get_connector_by_type(
+ SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
+ )
+ if webcrawler_connector and webcrawler_connector.config:
+ firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
+ _perf_log.info(
+ "[stream_new_chat] Connector service + firecrawl key in %.3fs",
+ time.perf_counter() - _t0,
+ )
+
+ # Get the PostgreSQL checkpointer for persistent conversation memory
+ _t0 = time.perf_counter()
+ checkpointer = await get_checkpointer()
+ _perf_log.info(
+ "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
+ )
+
+ visibility = thread_visibility or ChatVisibility.PRIVATE
+ _t0 = time.perf_counter()
+ # Speculative agent build — runs in parallel with the preflight
+ # task (if any). Built with the *current* ``llm`` / ``agent_config``;
+ # if preflight reports 429 we will discard this future and rebuild
+ # against the freshly pinned config below.
+ agent_build_task = asyncio.create_task(
+ create_surfsense_deep_agent(
+ llm=llm,
+ search_space_id=search_space_id,
+ db_session=session,
+ connector_service=connector_service,
+ checkpointer=checkpointer,
+ user_id=user_id,
+ thread_id=chat_id,
+ agent_config=agent_config,
+ firecrawl_api_key=firecrawl_api_key,
+ thread_visibility=visibility,
+ disabled_tools=disabled_tools,
+ mentioned_document_ids=mentioned_document_ids,
+ filesystem_selection=filesystem_selection,
+ ),
+ name="agent_build:stream_new_chat",
+ )
+
+ agent: Any = None
+ if preflight_task is not None:
try:
- await _preflight_llm(llm)
+ await preflight_task
mark_healthy(llm_config_id)
_perf_log.info(
- "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs",
+ "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
llm_config_id,
time.perf_counter() - _t_preflight,
)
except Exception as preflight_exc:
+ # Both branches below need the session: the non-429 path
+ # may unwind via cleanup that uses ``session``, and the
+ # 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id``
+ # against it. Wait for the speculative build to release its
+ # session usage before we proceed.
+ await _settle_speculative_agent_build(agent_build_task)
if not _is_provider_rate_limited(preflight_exc):
raise
+ # 429: speculative agent is discarded; run the original
+ # repin → reload → rebuild path against the freshly
+ # pinned config.
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id, reason="preflight_rate_limited"
@@ -2639,46 +2754,28 @@ async def stream_new_chat(
"fallback_config_id": llm_config_id,
},
)
+ # Rebuild against the new llm/agent_config. Sequential
+ # here because we no longer have anything to overlap with.
+ agent = await create_surfsense_deep_agent(
+ llm=llm,
+ search_space_id=search_space_id,
+ db_session=session,
+ connector_service=connector_service,
+ checkpointer=checkpointer,
+ user_id=user_id,
+ thread_id=chat_id,
+ agent_config=agent_config,
+ firecrawl_api_key=firecrawl_api_key,
+ thread_visibility=visibility,
+ disabled_tools=disabled_tools,
+ mentioned_document_ids=mentioned_document_ids,
+ filesystem_selection=filesystem_selection,
+ )
- # Create connector service
- _t0 = time.perf_counter()
- connector_service = ConnectorService(session, search_space_id=search_space_id)
-
- firecrawl_api_key = None
- webcrawler_connector = await connector_service.get_connector_by_type(
- SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
- )
- if webcrawler_connector and webcrawler_connector.config:
- firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
- _perf_log.info(
- "[stream_new_chat] Connector service + firecrawl key in %.3fs",
- time.perf_counter() - _t0,
- )
-
- # Get the PostgreSQL checkpointer for persistent conversation memory
- _t0 = time.perf_counter()
- checkpointer = await get_checkpointer()
- _perf_log.info(
- "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
- )
-
- visibility = thread_visibility or ChatVisibility.PRIVATE
- _t0 = time.perf_counter()
- agent = await create_surfsense_deep_agent(
- llm=llm,
- search_space_id=search_space_id,
- db_session=session,
- connector_service=connector_service,
- checkpointer=checkpointer,
- user_id=user_id,
- thread_id=chat_id,
- agent_config=agent_config,
- firecrawl_api_key=firecrawl_api_key,
- thread_visibility=visibility,
- disabled_tools=disabled_tools,
- mentioned_document_ids=mentioned_document_ids,
- filesystem_selection=filesystem_selection,
- )
+ if agent is None:
+ # Either no preflight was needed, or preflight succeeded —
+ # in both cases the speculative build is the agent we want.
+ agent = await agent_build_task
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
)
@@ -3005,6 +3102,18 @@ async def stream_new_chat(
title_emitted = False
+ # Build the per-invocation runtime context (Phase 1.5).
+ # ``mentioned_document_ids`` is read by ``KnowledgePriorityMiddleware``
+ # via ``runtime.context.mentioned_document_ids`` instead of its
+ # ``__init__`` closure — that way the same compiled-agent instance
+ # can serve multiple turns with different mention lists.
+ runtime_context = SurfSenseContextSchema(
+ search_space_id=search_space_id,
+ mentioned_document_ids=list(mentioned_document_ids or []),
+ request_id=request_id,
+ turn_id=stream_result.turn_id,
+ )
+
_t_stream_start = time.perf_counter()
_first_event_logged = False
runtime_rate_limit_recovered = False
@@ -3028,6 +3137,7 @@ async def stream_new_chat(
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
+ runtime_context=runtime_context,
):
if not _first_event_logged:
_perf_log.info(
@@ -3643,21 +3753,75 @@ async def stream_resume_chat(
# Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
# one cheap probe before the agent is rebuilt so a 429'd pin gets
# repinned without burning planner/classifier/title calls first.
- if (
+ # See ``stream_new_chat`` for the full rationale on the speculative
+ # parallel build pattern below.
+ preflight_needed = (
requested_llm_config_id == 0
and llm_config_id < 0
and not is_recently_healthy(llm_config_id)
- ):
+ )
+ preflight_task: asyncio.Task[None] | None = None
+ _t_preflight = 0.0
+ if preflight_needed:
_t_preflight = time.perf_counter()
+ preflight_task = asyncio.create_task(
+ _preflight_llm(llm),
+ name=f"auto_pin_preflight_resume:{llm_config_id}",
+ )
+
+ _t0 = time.perf_counter()
+ connector_service = ConnectorService(session, search_space_id=search_space_id)
+
+ firecrawl_api_key = None
+ webcrawler_connector = await connector_service.get_connector_by_type(
+ SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
+ )
+ if webcrawler_connector and webcrawler_connector.config:
+ firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
+ _perf_log.info(
+ "[stream_resume] Connector service + firecrawl key in %.3fs",
+ time.perf_counter() - _t0,
+ )
+
+ _t0 = time.perf_counter()
+ checkpointer = await get_checkpointer()
+ _perf_log.info(
+ "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
+ )
+
+ visibility = thread_visibility or ChatVisibility.PRIVATE
+
+ _t0 = time.perf_counter()
+ agent_build_task = asyncio.create_task(
+ create_surfsense_deep_agent(
+ llm=llm,
+ search_space_id=search_space_id,
+ db_session=session,
+ connector_service=connector_service,
+ checkpointer=checkpointer,
+ user_id=user_id,
+ thread_id=chat_id,
+ agent_config=agent_config,
+ firecrawl_api_key=firecrawl_api_key,
+ thread_visibility=visibility,
+ filesystem_selection=filesystem_selection,
+ ),
+ name="agent_build:stream_resume",
+ )
+
+ agent: Any = None
+ if preflight_task is not None:
try:
- await _preflight_llm(llm)
+ await preflight_task
mark_healthy(llm_config_id)
_perf_log.info(
- "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs",
+ "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
llm_config_id,
time.perf_counter() - _t_preflight,
)
except Exception as preflight_exc:
+ # Same session-safety rationale as ``stream_new_chat``.
+ await _settle_speculative_agent_build(agent_build_task)
if not _is_provider_rate_limited(preflight_exc):
raise
previous_config_id = llm_config_id
@@ -3717,43 +3881,22 @@ async def stream_resume_chat(
"fallback_config_id": llm_config_id,
},
)
+ agent = await create_surfsense_deep_agent(
+ llm=llm,
+ search_space_id=search_space_id,
+ db_session=session,
+ connector_service=connector_service,
+ checkpointer=checkpointer,
+ user_id=user_id,
+ thread_id=chat_id,
+ agent_config=agent_config,
+ firecrawl_api_key=firecrawl_api_key,
+ thread_visibility=visibility,
+ filesystem_selection=filesystem_selection,
+ )
- _t0 = time.perf_counter()
- connector_service = ConnectorService(session, search_space_id=search_space_id)
-
- firecrawl_api_key = None
- webcrawler_connector = await connector_service.get_connector_by_type(
- SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
- )
- if webcrawler_connector and webcrawler_connector.config:
- firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
- _perf_log.info(
- "[stream_resume] Connector service + firecrawl key in %.3fs",
- time.perf_counter() - _t0,
- )
-
- _t0 = time.perf_counter()
- checkpointer = await get_checkpointer()
- _perf_log.info(
- "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
- )
-
- visibility = thread_visibility or ChatVisibility.PRIVATE
-
- _t0 = time.perf_counter()
- agent = await create_surfsense_deep_agent(
- llm=llm,
- search_space_id=search_space_id,
- db_session=session,
- connector_service=connector_service,
- checkpointer=checkpointer,
- user_id=user_id,
- thread_id=chat_id,
- agent_config=agent_config,
- firecrawl_api_key=firecrawl_api_key,
- thread_visibility=visibility,
- filesystem_selection=filesystem_selection,
- )
+ if agent is None:
+ agent = await agent_build_task
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
)
@@ -3794,6 +3937,16 @@ async def stream_resume_chat(
)
yield streaming_service.format_data("turn-status", {"status": "busy"})
+ # Resume path doesn't carry new ``mentioned_document_ids`` —
+ # those are seeded in the original turn. We still pass a
+ # context so future middleware extensions (Phase 2) can rely on
+ # ``runtime.context`` always being populated.
+ runtime_context = SurfSenseContextSchema(
+ search_space_id=search_space_id,
+ request_id=request_id,
+ turn_id=stream_result.turn_id,
+ )
+
_t_stream_start = time.perf_counter()
_first_event_logged = False
runtime_rate_limit_recovered = False
@@ -3814,6 +3967,7 @@ async def stream_resume_chat(
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
+ runtime_context=runtime_context,
):
if not _first_event_logged:
_perf_log.info(
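
Stripped of logging and repin detail, the speculative preflight/build
coordination shared by both streaming paths reduces to the following sketch
(names are illustrative; the real code uses _preflight_llm,
create_surfsense_deep_agent and _settle_speculative_agent_build):

    import asyncio
    import contextlib
    from collections.abc import Awaitable, Callable
    from typing import Any

    async def build_with_preflight(
        preflight: Callable[[], Awaitable[None]],
        build: Callable[[], Awaitable[Any]],
        rebuild_after_repin: Callable[[], Awaitable[Any]],
        is_rate_limited: Callable[[Exception], bool],
    ) -> Any:
        preflight_task = asyncio.create_task(preflight())
        build_task = asyncio.create_task(build())  # speculative, shares the session
        try:
            await preflight_task
        except Exception as exc:
            # Let the discarded build settle before the session is touched again.
            with contextlib.suppress(BaseException):
                await build_task
            if not is_rate_limited(exc):
                raise
            return await rebuild_after_repin()
        # Common case: preflight succeeded while the build was still compiling.
        return await build_task
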
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py
new file mode 100644
index 000000000..9b3de2db7
--- /dev/null
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py
@@ -0,0 +1,268 @@
+"""Regression tests for the compiled-agent cache.
+
+Covers the cache primitive itself (TTL, LRU, in-flight de-duplication,
+build-failure non-caching) and the cache-key signature helpers that
+``create_surfsense_deep_agent`` relies on. The integration with
+``create_surfsense_deep_agent`` is covered separately by the streaming
+contract tests; this module focuses on the primitives so a regression
+in the cache implementation is caught before it reaches the agent
+factory.
+"""
+
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass
+
+import pytest
+
+from app.agents.new_chat.agent_cache import (
+ flags_signature,
+ reload_for_tests,
+ stable_hash,
+ system_prompt_hash,
+ tools_signature,
+)
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# stable_hash + signature helpers
+# ---------------------------------------------------------------------------
+
+
+def test_stable_hash_is_deterministic_across_calls() -> None:
+ a = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
+ b = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
+ assert a == b
+
+
+def test_stable_hash_changes_when_any_part_changes() -> None:
+ base = stable_hash("v1", 42, "thread-9")
+ assert stable_hash("v1", 42, "thread-10") != base
+ assert stable_hash("v2", 42, "thread-9") != base
+ assert stable_hash("v1", 43, "thread-9") != base
+
+
+def test_tools_signature_keys_on_name_and_description_not_identity() -> None:
+ """Two tool lists with the same surface must hash identically.
+
+ The cache key MUST NOT change when the underlying ``BaseTool``
+ instances are different Python objects (a fresh request constructs
+ fresh tool instances every time). Hashing on ``(name, description)``
+ keeps the cache hot across requests with identical tool surfaces.
+ """
+
+ @dataclass
+ class FakeTool:
+ name: str
+ description: str
+
+ tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")]
+ tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")]
+ sig_a = tools_signature(
+ tools_a, available_connectors=["NOTION"], available_document_types=["FILE"]
+ )
+ sig_b = tools_signature(
+ tools_b, available_connectors=["NOTION"], available_document_types=["FILE"]
+ )
+ assert sig_a == sig_b, "tool order must not affect the signature"
+
+ # Adding a tool rotates the key.
+ tools_c = [*tools_a, FakeTool("gamma", "does gamma")]
+ sig_c = tools_signature(
+ tools_c, available_connectors=["NOTION"], available_document_types=["FILE"]
+ )
+ assert sig_c != sig_a
+
+
+def test_tools_signature_rotates_when_connector_set_changes() -> None:
+ @dataclass
+ class FakeTool:
+ name: str
+ description: str
+
+ tools = [FakeTool("a", "x")]
+ base = tools_signature(
+ tools, available_connectors=["NOTION"], available_document_types=["FILE"]
+ )
+ added = tools_signature(
+ tools,
+ available_connectors=["NOTION", "SLACK"],
+ available_document_types=["FILE"],
+ )
+ assert base != added, "adding a connector must rotate the cache key"
+
+
+def test_flags_signature_changes_when_flag_flips() -> None:
+ @dataclass(frozen=True)
+ class Flags:
+ a: bool = True
+ b: bool = False
+
+ base = flags_signature(Flags())
+ flipped = flags_signature(Flags(b=True))
+ assert base != flipped
+
+
+def test_system_prompt_hash_is_stable_and_distinct() -> None:
+ p1 = "You are a helpful assistant."
+ p2 = "You are a helpful assistant!" # one-character delta
+ assert system_prompt_hash(p1) == system_prompt_hash(p1)
+ assert system_prompt_hash(p1) != system_prompt_hash(p2)
+
+
+# ---------------------------------------------------------------------------
+# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_cache_hit_returns_same_instance_on_second_call() -> None:
+ cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
+ builds = 0
+
+ async def builder() -> object:
+ nonlocal builds
+ builds += 1
+ return object()
+
+ a = await cache.get_or_build("k", builder=builder)
+ b = await cache.get_or_build("k", builder=builder)
+ assert a is b, "cache must return the SAME object across hits"
+ assert builds == 1, "builder must run exactly once"
+
+
+@pytest.mark.asyncio
+async def test_cache_different_keys_get_different_instances() -> None:
+ cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
+
+ async def builder() -> object:
+ return object()
+
+ a = await cache.get_or_build("k1", builder=builder)
+ b = await cache.get_or_build("k2", builder=builder)
+ assert a is not b
+
+
+@pytest.mark.asyncio
+async def test_cache_stale_entries_get_rebuilt() -> None:
+ # ttl=0 means every read sees the entry as immediately stale.
+ cache = reload_for_tests(maxsize=8, ttl_seconds=0.0)
+ builds = 0
+
+ async def builder() -> object:
+ nonlocal builds
+ builds += 1
+ return object()
+
+ a = await cache.get_or_build("k", builder=builder)
+ b = await cache.get_or_build("k", builder=builder)
+ assert a is not b, "stale entry must rebuild a fresh instance"
+ assert builds == 2
+
+
+@pytest.mark.asyncio
+async def test_cache_evicts_lru_when_full() -> None:
+ cache = reload_for_tests(maxsize=2, ttl_seconds=60.0)
+
+ async def builder() -> object:
+ return object()
+
+ a = await cache.get_or_build("a", builder=builder)
+ _ = await cache.get_or_build("b", builder=builder)
+ # Re-touch "a" so "b" is now the LRU victim.
+ a_again = await cache.get_or_build("a", builder=builder)
+ assert a_again is a
+ # Inserting "c" should evict "b" (LRU), not "a".
+ _ = await cache.get_or_build("c", builder=builder)
+ assert cache.stats()["size"] == 2
+
+ # Confirm "a" is still hot (no rebuild) and "b" is gone (rebuild).
+ a_hit = await cache.get_or_build("a", builder=builder)
+ assert a_hit is a, "LRU must keep the most-recently-used 'a' entry"
+
+
+@pytest.mark.asyncio
+async def test_cache_concurrent_misses_coalesce_to_single_build() -> None:
+ """Two concurrent get_or_build calls on the same key must share one builder."""
+ cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
+ build_started = asyncio.Event()
+ builds = 0
+
+ async def slow_builder() -> object:
+ nonlocal builds
+ builds += 1
+ build_started.set()
+ # Yield control so the second waiter can race against us.
+ await asyncio.sleep(0.05)
+ return object()
+
+ task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
+ # Wait until the first builder has started, then race a second waiter.
+ await build_started.wait()
+ task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
+
+ a, b = await asyncio.gather(task_a, task_b)
+ assert a is b, "coalesced waiters must observe the same value"
+ assert builds == 1, "concurrent cold misses must collapse to ONE build"
+
+
+@pytest.mark.asyncio
+async def test_cache_does_not_store_failed_builds() -> None:
+ """A builder that raises must NOT poison the cache.
+
+ The next caller for the same key must run the builder again (not
+ re-raise the cached exception).
+ """
+ cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
+ attempts = 0
+
+ async def flaky_builder() -> object:
+ nonlocal attempts
+ attempts += 1
+ if attempts == 1:
+ raise RuntimeError("transient")
+ return object()
+
+ with pytest.raises(RuntimeError, match="transient"):
+ await cache.get_or_build("k", builder=flaky_builder)
+
+ # Second call must retry — not re-raise the cached exception.
+ value = await cache.get_or_build("k", builder=flaky_builder)
+ assert value is not None
+ assert attempts == 2
+
+
+@pytest.mark.asyncio
+async def test_cache_invalidate_drops_entry() -> None:
+ cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
+
+ async def builder() -> object:
+ return object()
+
+ a = await cache.get_or_build("k", builder=builder)
+ assert cache.invalidate("k") is True
+ b = await cache.get_or_build("k", builder=builder)
+ assert a is not b, "post-invalidation lookup must rebuild"
+
+
+@pytest.mark.asyncio
+async def test_cache_invalidate_prefix_drops_matching_entries() -> None:
+ cache = reload_for_tests(maxsize=16, ttl_seconds=60.0)
+
+ async def builder() -> object:
+ return object()
+
+ await cache.get_or_build("user:1:thread:1", builder=builder)
+ await cache.get_or_build("user:1:thread:2", builder=builder)
+ await cache.get_or_build("user:2:thread:1", builder=builder)
+
+ removed = cache.invalidate_prefix("user:1:")
+ assert removed == 2
+ assert cache.stats()["size"] == 1
+
+ # The user:2 entry must still be hot (no rebuild).
+ survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder)
+ assert survivor_value is not None
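
How the signature helpers exercised above would typically compose into one
cache key; a sketch only, since the factory's real key layout lives in
create_surfsense_deep_agent / agent_cache and may include additional parts:

    from typing import Any

    from app.agents.new_chat.agent_cache import (
        flags_signature,
        stable_hash,
        system_prompt_hash,
        tools_signature,
    )

    def example_agent_cache_key(
        user_id: str,
        thread_id: str,
        search_space_id: int,
        llm_config_id: int,
        tools: list[Any],
        connectors: list[str],
        doc_types: list[str],
        flags: Any,
        system_prompt: str,
    ) -> str:
        # Everything that changes the compiled graph belongs in the key;
        # per-turn inputs (e.g. mentioned_document_ids) must not, since
        # they now travel via runtime.context instead of the build closure.
        return stable_hash(
            user_id,
            thread_id,
            search_space_id,
            llm_config_id,
            tools_signature(
                tools,
                available_connectors=connectors,
                available_document_types=doc_types,
            ),
            flags_signature(flags),
            system_prompt_hash(system_prompt),
        )
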
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py
index df60a4816..6800be2af 100644
--- a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py
@@ -34,6 +34,8 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
"SURFSENSE_ENABLE_OTEL",
+ "SURFSENSE_ENABLE_AGENT_CACHE",
+ "SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT",
]:
monkeypatch.delenv(name, raising=False)
@@ -62,6 +64,11 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) ->
assert flags.enable_stream_parity_v2 is True
assert flags.enable_plugin_loader is False
assert flags.enable_otel is False
+ # Phase 2: agent cache is now default-on (the prerequisite tool
+ # ``db_session`` refactor landed). The companion gp-subagent share
+ # flag stays default-off pending data on cold-miss frequency.
+ assert flags.enable_agent_cache is True
+ assert flags.enable_agent_cache_share_gp_subagent is False
assert flags.any_new_middleware_enabled() is True
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py
new file mode 100644
index 000000000..6c323d920
--- /dev/null
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py
@@ -0,0 +1,344 @@
+"""Tests for ``FlattenSystemMessageMiddleware``.
+
+The middleware exists to defend against Anthropic's "Found 5 cache_control
+blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on
+the system message and the OpenRouter→Anthropic adapter redistributes
+``cache_control`` across all of them. The flattening collapses every
+all-text system content list to a single string before the LLM call.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+from langchain_core.messages import HumanMessage, SystemMessage
+
+from app.agents.new_chat.middleware.flatten_system import (
+ FlattenSystemMessageMiddleware,
+ _flatten_text_blocks,
+ _flattened_request,
+)
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# _flatten_text_blocks — pure helper, the heart of the middleware.
+# ---------------------------------------------------------------------------
+
+
+class TestFlattenTextBlocks:
+ def test_joins_text_blocks_with_double_newline(self) -> None:
+ blocks = [
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ ]
+ assert (
+ _flatten_text_blocks(blocks)
+ == "\n\n\n\n"
+ )
+
+ def test_handles_single_text_block(self) -> None:
+ blocks = [{"type": "text", "text": "only one"}]
+ assert _flatten_text_blocks(blocks) == "only one"
+
+ def test_handles_empty_list(self) -> None:
+ assert _flatten_text_blocks([]) == ""
+
+ def test_passes_through_bare_string_blocks(self) -> None:
+ # LangChain content can mix bare strings and dict blocks.
+ blocks = ["raw string", {"type": "text", "text": "dict block"}]
+ assert _flatten_text_blocks(blocks) == "raw string\n\ndict block"
+
+ def test_returns_none_for_image_block(self) -> None:
+ # System messages with images are rare — but we never want to
+ # silently lose the image payload by joining as text.
+ blocks = [
+ {"type": "text", "text": "look at this"},
+ {"type": "image_url", "image_url": {"url": "data:image/png..."}},
+ ]
+ assert _flatten_text_blocks(blocks) is None
+
+ def test_returns_none_for_non_dict_non_str_block(self) -> None:
+ blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item]
+ assert _flatten_text_blocks(blocks) is None
+
+ def test_returns_none_when_text_field_missing(self) -> None:
+ blocks = [{"type": "text"}] # no ``text`` key
+ assert _flatten_text_blocks(blocks) is None
+
+ def test_returns_none_when_text_is_not_string(self) -> None:
+ blocks = [{"type": "text", "text": ["nested", "list"]}]
+ assert _flatten_text_blocks(blocks) is None
+
+ def test_drops_cache_control_from_inner_blocks(self) -> None:
+ # The whole point: existing cache_control on inner blocks is
+ # discarded so LiteLLM's ``cache_control_injection_points`` can
+ # re-attach exactly one breakpoint after flattening.
+ blocks = [
+ {"type": "text", "text": "first"},
+ {
+ "type": "text",
+ "text": "second",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ]
+ flattened = _flatten_text_blocks(blocks)
+ assert flattened == "first\n\nsecond"
+ assert "cache_control" not in flattened # type: ignore[operator]
+
+
+# ---------------------------------------------------------------------------
+# _flattened_request — decides when to override and when to no-op.
+# ---------------------------------------------------------------------------
+
+
+def _make_request(system_message: SystemMessage | None) -> Any:
+ """Build a minimal ModelRequest stub. We only need .system_message
+ and .override(system_message=...) — the middleware never touches
+ other fields.
+ """
+ request = MagicMock()
+ request.system_message = system_message
+
+ def override(**kwargs: Any) -> Any:
+ new_request = MagicMock()
+ new_request.system_message = kwargs.get(
+ "system_message", request.system_message
+ )
+ new_request.messages = kwargs.get("messages", getattr(request, "messages", []))
+ new_request.tools = kwargs.get("tools", getattr(request, "tools", []))
+ return new_request
+
+ request.override = override
+ return request
+
+
+class TestFlattenedRequest:
+ def test_collapses_multi_block_system_to_string(self) -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": " "},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ ]
+ )
+ request = _make_request(sys)
+ flattened = _flattened_request(request)
+
+ assert flattened is not None
+ assert isinstance(flattened.system_message, SystemMessage)
+ assert flattened.system_message.content == (
+ " \n\n\n\n\n\n\n\n"
+ )
+
+ def test_no_op_for_string_content(self) -> None:
+ sys = SystemMessage(content="already a string")
+ request = _make_request(sys)
+ assert _flattened_request(request) is None
+
+ def test_no_op_for_single_block_list(self) -> None:
+ # One block already produces one breakpoint — no need to flatten.
+ sys = SystemMessage(content=[{"type": "text", "text": "single"}])
+ request = _make_request(sys)
+ assert _flattened_request(request) is None
+
+ def test_no_op_when_system_message_missing(self) -> None:
+ request = _make_request(None)
+ assert _flattened_request(request) is None
+
+ def test_no_op_when_list_contains_non_text_block(self) -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "look"},
+ {"type": "image_url", "image_url": {"url": "data:..."}},
+ ]
+ )
+ request = _make_request(sys)
+ assert _flattened_request(request) is None
+
+ def test_preserves_additional_kwargs_and_metadata(self) -> None:
+ # Defensive: nothing in the current chain sets these on a system
+ # message, but losing them silently when something does in the
+ # future would be a regression. ``name`` in particular is the only
+ # ``additional_kwargs`` field that ChatLiteLLM's
+ # ``_convert_message_to_dict`` propagates onto the wire.
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "a"},
+ {"type": "text", "text": "b"},
+ ],
+ additional_kwargs={"name": "surfsense_system", "x": 1},
+ response_metadata={"tokens": 42},
+ )
+ sys.id = "sys-msg-1"
+ request = _make_request(sys)
+
+ flattened = _flattened_request(request)
+ assert flattened is not None
+ assert flattened.system_message.content == "a\n\nb"
+ assert flattened.system_message.additional_kwargs == {
+ "name": "surfsense_system",
+ "x": 1,
+ }
+ assert flattened.system_message.response_metadata == {"tokens": 42}
+ assert flattened.system_message.id == "sys-msg-1"
+
+ def test_idempotent_when_run_twice(self) -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "a"},
+ {"type": "text", "text": "b"},
+ ]
+ )
+ request = _make_request(sys)
+ first = _flattened_request(request)
+ assert first is not None
+
+ # Second pass on the already-flattened request should be a no-op.
+ # We re-wrap in a request stub since the helper inspects
+ # ``request.system_message.content``.
+ second_request = _make_request(first.system_message)
+ assert _flattened_request(second_request) is None
+
+
+# ---------------------------------------------------------------------------
+# Middleware integration — verify the handler sees a flattened request.
+# ---------------------------------------------------------------------------
+
+
+class TestMiddlewareWrap:
+ @pytest.mark.asyncio
+ async def test_async_passes_flattened_request_to_handler(self) -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "alpha"},
+ {"type": "text", "text": "beta"},
+ ]
+ )
+ request = _make_request(sys)
+ captured: dict[str, Any] = {}
+
+ async def handler(req: Any) -> str:
+ captured["request"] = req
+ return "ok"
+
+ mw = FlattenSystemMessageMiddleware()
+ result = await mw.awrap_model_call(request, handler)
+
+ assert result == "ok"
+ assert isinstance(captured["request"].system_message, SystemMessage)
+ assert captured["request"].system_message.content == "alpha\n\nbeta"
+
+ @pytest.mark.asyncio
+ async def test_async_passes_through_when_already_string(self) -> None:
+ sys = SystemMessage(content="just a string")
+ request = _make_request(sys)
+ captured: dict[str, Any] = {}
+
+ async def handler(req: Any) -> str:
+ captured["request"] = req
+ return "ok"
+
+ mw = FlattenSystemMessageMiddleware()
+ await mw.awrap_model_call(request, handler)
+
+ # Same request object: no override happened.
+ assert captured["request"] is request
+
+ def test_sync_passes_flattened_request_to_handler(self) -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "alpha"},
+ {"type": "text", "text": "beta"},
+ ]
+ )
+ request = _make_request(sys)
+ captured: dict[str, Any] = {}
+
+ def handler(req: Any) -> str:
+ captured["request"] = req
+ return "ok"
+
+ mw = FlattenSystemMessageMiddleware()
+ result = mw.wrap_model_call(request, handler)
+
+ assert result == "ok"
+ assert captured["request"].system_message.content == "alpha\n\nbeta"
+
+ def test_sync_passes_through_when_no_system_message(self) -> None:
+ request = _make_request(None)
+ captured: dict[str, Any] = {}
+
+ def handler(req: Any) -> str:
+ captured["request"] = req
+ return "ok"
+
+ mw = FlattenSystemMessageMiddleware()
+ mw.wrap_model_call(request, handler)
+ assert captured["request"] is request
+
+
+# ---------------------------------------------------------------------------
+# Regression guard — pin the worst-case shape that triggered the
+# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
+# downstream cache_control_injection_points can only place 1 breakpoint
+# on the system message regardless of provider redistribution quirks.
+# ---------------------------------------------------------------------------
+
+
+def test_regression_five_block_system_collapses_to_one_block() -> None:
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ {"type": "text", "text": ""},
+ ]
+ )
+ request = _make_request(sys)
+ flattened = _flattened_request(request)
+
+ assert flattened is not None
+ assert isinstance(flattened.system_message.content, str)
+ # The exact join doesn't matter for the cache_control accounting —
+ # only that there is exactly ONE content block when LiteLLM's
+ # AnthropicCacheControlHook later targets ``role: system``.
+ assert " None:
+ # Sanity: the middleware MUST NOT touch user messages — only the
+ # system message. Multi-block user content is the path that carries
+ # image attachments and would lose its image_url block on
+ # accidental flatten.
+ sys = SystemMessage(
+ content=[
+ {"type": "text", "text": "a"},
+ {"type": "text", "text": "b"},
+ ]
+ )
+ user = HumanMessage(
+ content=[
+ {"type": "text", "text": "look at this"},
+ {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
+ ]
+ )
+ request = _make_request(sys)
+ request.messages = [user]
+
+ flattened = _flattened_request(request)
+ assert flattened is not None
+ # System flattened to string …
+ assert isinstance(flattened.system_message.content, str)
+ # … user message is untouched (the helper does not even look at it).
+ assert flattened.messages == [user]
+ assert isinstance(user.content, list)
+ assert len(user.content) == 2
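
For reference, a helper consistent with the contract these tests pin would
look roughly like this (a sketch, not the actual code in flatten_system.py):

    def flatten_text_blocks(blocks: list) -> str | None:
        """Join all-text blocks with a blank line; return None to leave content as-is."""
        parts: list[str] = []
        for block in blocks:
            if isinstance(block, str):
                parts.append(block)
                continue
            if not isinstance(block, dict) or block.get("type") != "text":
                return None  # image / unknown block: refuse to flatten
            text = block.get("text")
            if not isinstance(text, str):
                return None
            parts.append(text)  # any cache_control on the block is dropped
        return "\n\n".join(parts)
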
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py
index 5b3a03581..4cf53969d 100644
--- a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py
@@ -1,4 +1,4 @@
-"""Tests for ``apply_litellm_prompt_caching`` in
+r"""Tests for ``apply_litellm_prompt_caching`` in
:mod:`app.agents.new_chat.prompt_caching`.
The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
@@ -6,9 +6,12 @@ never activated for our LiteLLM stack) with LiteLLM-native multi-provider
prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
``litellm.completion(...)``. The tests below pin its public contract:
-1. Always sets BOTH ``role: system`` and ``index: -1`` injection points so
+1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so
savings compound across multi-turn conversations on Anthropic-family
- providers.
+ providers. ``index: 0`` is used (rather than ``role: system``) because
+ the deepagent stack accumulates multiple ``SystemMessage``\ s in
+ ``state["messages"]`` and ``role: system`` would tag every one of
+ them, blowing past Anthropic's 4-block ``cache_control`` cap.
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
prompt-cache surface is available).
@@ -92,11 +95,28 @@ def test_sets_both_cache_control_injection_points_with_no_config() -> None:
apply_litellm_prompt_caching(llm)
points = llm.model_kwargs["cache_control_injection_points"]
- assert {"location": "message", "role": "system"} in points
+ assert {"location": "message", "index": 0} in points
assert {"location": "message", "index": -1} in points
assert len(points) == 2
+def test_does_not_inject_role_system_breakpoint() -> None:
+ """Regression: deliberately AVOID ``role: system`` so we don't tag
+ every SystemMessage the deepagent ``before_agent`` injectors push
+ into ``state["messages"]`` (priority, tree, memory, file-intent,
+ anonymous-doc). Tagging all of them overflows Anthropic's 4-block
+ ``cache_control`` cap and surfaces as
+ ``OpenrouterException: A maximum of 4 blocks with cache_control may
+ be provided. Found N`` 400s.
+ """
+ llm = _FakeLLM()
+ apply_litellm_prompt_caching(llm)
+ points = llm.model_kwargs["cache_control_injection_points"]
+ assert all(p.get("role") != "system" for p in points), (
+ f"Expected no role=system breakpoint, got: {points}"
+ )
+
+
def test_injection_points_set_for_anthropic_config() -> None:
"""Anthropic-family configs need the marker — verify it lands."""
cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet")
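For reference, the kwargs shape these assertions pin down, expressed as what `apply_litellm_prompt_caching` leaves on `llm.model_kwargs` and what therefore flows into `litellm.completion(...)`. A sketch of the data only; nothing below is executed against a provider.

    # One breakpoint on the first message, one on the most recent message,
    # and never a role-based system breakpoint.
    model_kwargs = {
        "cache_control_injection_points": [
            {"location": "message", "index": 0},
            {"location": "message", "index": -1},
        ]
    }

For single-model OPENAI/DEEPSEEK/XAI configs the helper additionally sets the OpenAI-side `prompt_cache_key`/`prompt_cache_retention` keys on the same dict, per the module docstring above.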
diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py
index 2ca470680..2933a0504 100644
--- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py
+++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py
@@ -475,3 +475,190 @@ class TestKBSearchPlanSchema:
)
)
assert plan.is_recency_query is False
+
+
+# ── mentioned_document_ids cross-turn drain ────────────────────────────
+
+
+class TestKnowledgePriorityMentionDrain:
+ """Regression tests for the cross-turn ``mentioned_document_ids`` drain.
+
+ The compiled-agent cache reuses a single :class:`KnowledgeBaseSearchMiddleware`
+ instance across turns of the same thread. ``mentioned_document_ids``
+ can therefore enter the middleware via two paths:
+
+ 1. The constructor closure (``__init__(mentioned_document_ids=...)``) —
+ seeded by the cache-miss build on turn 1.
+ 2. ``runtime.context.mentioned_document_ids`` — supplied freshly per
+ turn by the streaming task.
+
+ Without the drain fix, an empty ``runtime.context.mentioned_document_ids``
+ on turn 2 would fall through to the closure (because ``[]`` is falsy in
+ Python) and replay turn 1's mentions. This class pins down the
+ correct behaviour: the runtime path is authoritative even when empty,
+ and the closure is drained the first time the runtime path fires so
+ no later turn can ever resurrect stale state.
+ """
+
+ @staticmethod
+ def _make_runtime(mention_ids: list[int]):
+ """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
+ from types import SimpleNamespace
+
+ return SimpleNamespace(
+ context=SimpleNamespace(mentioned_document_ids=mention_ids),
+ )
+
+ @staticmethod
+ def _planner_llm() -> "FakeLLM":
+ # Planner returns a stable, non-recency plan so we always land in
+ # the hybrid-search branch (where ``fetch_mentioned_documents`` is
+ # invoked alongside the main search).
+ return FakeLLM(
+ json.dumps(
+ {
+ "optimized_query": "follow up question",
+ "start_date": None,
+ "end_date": None,
+ "is_recency_query": False,
+ }
+ )
+ )
+
+ async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
+ """Turn 1 with mentions in BOTH closure and runtime context: the
+ runtime path wins AND the closure is drained so a future turn
+ cannot replay it.
+ """
+ fetched_ids: list[list[int]] = []
+
+ async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
+ fetched_ids.append(list(document_ids))
+ return []
+
+ async def fake_search_knowledge_base(**_kwargs):
+ return []
+
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
+ fake_fetch_mentioned_documents,
+ )
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
+ fake_search_knowledge_base,
+ )
+
+ middleware = KnowledgeBaseSearchMiddleware(
+ llm=self._planner_llm(),
+ search_space_id=42,
+ mentioned_document_ids=[1, 2, 3],
+ )
+
+ await middleware.abefore_agent(
+ {"messages": [HumanMessage(content="what is in those docs?")]},
+ runtime=self._make_runtime([1, 2, 3]),
+ )
+
+ assert fetched_ids == [[1, 2, 3]], (
+ "runtime.context mentions must be the source of truth on turn 1"
+ )
+ assert middleware.mentioned_document_ids == [], (
+ "closure must be drained the first time the runtime path fires "
+ "so no later turn can replay stale mentions"
+ )
+
+ async def test_empty_runtime_context_does_not_replay_closure_mentions(
+ self, monkeypatch
+ ):
+ """Regression: turn 2 with NO mentions must not surface turn 1's
+ mentions from the constructor closure.
+
+ Before the fix, ``if ctx_mentions:`` treated an empty list as
+ absent and fell through to ``elif self.mentioned_document_ids:``,
+ replaying turn 1's mentions. This test pins down the corrected
+ behaviour.
+ """
+ fetched_ids: list[list[int]] = []
+
+ async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
+ fetched_ids.append(list(document_ids))
+ return []
+
+ async def fake_search_knowledge_base(**_kwargs):
+ return []
+
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
+ fake_fetch_mentioned_documents,
+ )
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
+ fake_search_knowledge_base,
+ )
+
+ # Simulate a cached middleware instance whose closure was seeded
+ # by a previous turn's cache-miss build (mentions=[1,2,3]).
+ middleware = KnowledgeBaseSearchMiddleware(
+ llm=self._planner_llm(),
+ search_space_id=42,
+ mentioned_document_ids=[1, 2, 3],
+ )
+
+ # Turn 2: streaming task supplies an EMPTY mention list (no
+ # mentions on this follow-up turn).
+ await middleware.abefore_agent(
+ {"messages": [HumanMessage(content="what about the next steps?")]},
+ runtime=self._make_runtime([]),
+ )
+
+ assert fetched_ids == [], (
+ "fetch_mentioned_documents must NOT be called when the runtime "
+ "context says there are no mentions for this turn"
+ )
+
+ async def test_legacy_path_fires_only_when_runtime_context_absent(
+ self, monkeypatch
+ ):
+ """Backward-compat: if a caller doesn't supply runtime.context (old
+ non-streaming code path), the closure-injected mentions are still
+ honoured exactly once and then drained.
+ """
+ fetched_ids: list[list[int]] = []
+
+ async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
+ fetched_ids.append(list(document_ids))
+ return []
+
+ async def fake_search_knowledge_base(**_kwargs):
+ return []
+
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
+ fake_fetch_mentioned_documents,
+ )
+ monkeypatch.setattr(
+ "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
+ fake_search_knowledge_base,
+ )
+
+ middleware = KnowledgeBaseSearchMiddleware(
+ llm=self._planner_llm(),
+ search_space_id=42,
+ mentioned_document_ids=[7, 8],
+ )
+
+ # First call: no runtime → legacy path uses the closure.
+ await middleware.abefore_agent(
+ {"messages": [HumanMessage(content="initial question")]},
+ runtime=None,
+ )
+ # Second call: still no runtime — closure already drained, so no replay.
+ await middleware.abefore_agent(
+ {"messages": [HumanMessage(content="follow up")]},
+ runtime=None,
+ )
+
+ assert fetched_ids == [[7, 8]], (
+ "legacy path must honour the closure exactly once and then drain it"
+ )
+ assert middleware.mentioned_document_ids == []
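A compact sketch of the precedence these three tests describe. It is not the shipped `abefore_agent`; the free function and the SimpleNamespace stand-ins below exist only to make the drain rule concrete.

    from types import SimpleNamespace


    def resolve_mentions(middleware, runtime) -> list[int]:
        """Runtime context is authoritative, even when empty; the constructor
        closure is honoured at most once and then drained."""
        ctx = getattr(getattr(runtime, "context", None), "mentioned_document_ids", None)
        if ctx is not None:
            middleware.mentioned_document_ids = []  # drain the closure for good
            return list(ctx)
        # Legacy path (no runtime context): use the closure exactly once.
        mentions = list(middleware.mentioned_document_ids)
        middleware.mentioned_document_ids = []
        return mentions


    # Turn 2 of a cached instance, no mentions supplied this turn:
    mw = SimpleNamespace(mentioned_document_ids=[1, 2, 3])
    turn2 = SimpleNamespace(context=SimpleNamespace(mentioned_document_ids=[]))
    assert resolve_mentions(mw, turn2) == []   # nothing replayed from turn 1
    assert mw.mentioned_document_ids == []     # closure drained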
diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py
index cc8157464..64e4d5157 100644
--- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py
+++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py
@@ -271,6 +271,66 @@ async def test_preflight_skipped_for_auto_router_model():
await _preflight_llm(fake_llm)
+@pytest.mark.asyncio
+async def test_settle_speculative_agent_build_swallows_exceptions():
+ """``_settle_speculative_agent_build`` MUST always return cleanly so the
+ caller can safely re-touch the request-scoped session afterwards.
+
+ The helper guards the parallel preflight + agent-build path: when the
+ speculative build is being discarded (429 or non-429 preflight failure)
+ we await it solely to release any in-flight ``AsyncSession`` usage —
+ the build's outcome is irrelevant. Any exception (including
+ ``CancelledError``) leaking out would skip the caller's recovery flow
+ and re-introduce the very session-concurrency hazard the helper exists
+ to prevent.
+ """
+ import asyncio
+
+ from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
+
+ async def _raises() -> None:
+ raise RuntimeError("speculative build crashed")
+
+ async def _succeeds() -> str:
+ return "agent"
+
+ async def _slow() -> None:
+ await asyncio.sleep(0.05)
+
+ for coro in (_raises(), _succeeds(), _slow()):
+ task = asyncio.create_task(coro)
+ await _settle_speculative_agent_build(task)
+ assert task.done()
+
+
+@pytest.mark.asyncio
+async def test_settle_speculative_agent_build_handles_already_done_task():
+ """Done tasks (success or failure) must still be settled without raising."""
+ import asyncio
+
+ from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
+
+ async def _ok() -> str:
+ return "ok"
+
+ async def _bad() -> None:
+ raise ValueError("nope")
+
+ ok_task = asyncio.create_task(_ok())
+ bad_task = asyncio.create_task(_bad())
+ # Drive both to completion before settling.
+ await asyncio.sleep(0)
+ await asyncio.sleep(0)
+
+ await _settle_speculative_agent_build(ok_task)
+ await _settle_speculative_agent_build(bad_task)
+ assert ok_task.result() == "ok"
+ # ``bad_task``'s exception was retrieved by the settle helper (so asyncio
+ # won't warn that it was never retrieved); calling ``.exception()`` after
+ # the fact must still return the original error (the helper observes the
+ # exception but doesn't clear it).
+
+
def test_stream_exception_classifies_thread_busy():
exc = BusyError(request_id="thread-123")
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
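Restating the helper's contract as a sketch. This is an assumption about its shape, inferred from the two tests above rather than copied from stream_new_chat: await the discarded speculative build only to release its in-flight session usage, and let nothing escape, cancellation included.

    import asyncio
    import logging

    logger = logging.getLogger(__name__)


    async def settle_speculative_agent_build(task: asyncio.Task) -> None:
        """Sketch: always return cleanly so the caller's recovery flow runs."""
        try:
            await task
        except asyncio.CancelledError:
            # A cancelled build is an acceptable way for it to settle.
            pass
        except Exception:
            logger.debug("speculative agent build failed; discarding", exc_info=True)

Awaiting the task retrieves its exception (so asyncio will not log it as unretrieved), and `.exception()` still returns the original error afterwards, which is what the already-done test checks.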
diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx
index 07c11b4d6..7616d461d 100644
--- a/surfsense_web/components/pricing/pricing-section.tsx
+++ b/surfsense_web/components/pricing/pricing-section.tsx
@@ -254,8 +254,8 @@ function PricingFAQ() {
Frequently Asked Questions
- Everything you need to know about SurfSense pages, premium credits, and billing. Can't
- find what you need? Reach out at{" "}
+ Everything you need to know about SurfSense pages, premium credits, and billing.
+ Can't find what you need? Reach out at{" "}
rohan@surfsense.com
From 83378211258ca1cd50949d2f61598e00f3306289 Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Sun, 3 May 2026 19:14:16 -0700
Subject: [PATCH 11/12] fix(security): manual auth endpoint leaks
---
surfsense_backend/app/app.py | 101 ++++++++++++++++++++++++-----------
1 file changed, 70 insertions(+), 31 deletions(-)
diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py
index 2c9b4f390..08194e7fb 100644
--- a/surfsense_backend/app/app.py
+++ b/surfsense_backend/app/app.py
@@ -595,6 +595,23 @@ async def lifespan(app: FastAPI):
def registration_allowed():
+ """Master auth kill switch keyed on the REGISTRATION_ENABLED env var.
+
+ Despite the name, this dependency does NOT only gate registration. When
+ REGISTRATION_ENABLED is FALSE it intentionally blocks every auth surface
+ that could mint or refresh a session for an attacker:
+
+ * email/password ``POST /auth/register``
+ * email/password ``POST /auth/jwt/login``
+ * the Google OAuth router (``/auth/google/authorize``, plus the shared
+ ``/auth/google/callback``, which handles both new signups and login for
+ existing users, so flipping this off locks both)
+ * the bespoke ``/auth/google/authorize-redirect`` helper used by the UI
+
+ Use it as a temporary "freeze all new sessions" lever during incident
+ response. It is not a way to disable signup while keeping login working;
+ for that, override ``UserManager.oauth_callback`` instead.
+ """
if not config.REGISTRATION_ENABLED:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled"
@@ -739,32 +756,45 @@ app.add_middleware(
allow_headers=["*"], # Allows all headers
)
-app.include_router(
- fastapi_users.get_auth_router(auth_backend),
- prefix="/auth/jwt",
- tags=["auth"],
- dependencies=[Depends(rate_limit_login)],
-)
-app.include_router(
- fastapi_users.get_register_router(UserRead, UserCreate),
- prefix="/auth",
- tags=["auth"],
- dependencies=[
- Depends(rate_limit_register),
- Depends(registration_allowed), # blocks registration when disabled
- ],
-)
-app.include_router(
- fastapi_users.get_reset_password_router(),
- prefix="/auth",
- tags=["auth"],
- dependencies=[Depends(rate_limit_password_reset)],
-)
-app.include_router(
- fastapi_users.get_verify_router(UserRead),
- prefix="/auth",
- tags=["auth"],
-)
+# Password / email-based auth routers are only mounted when not running in
+# Google-OAuth-only mode. Mounting them in OAuth-only prod previously left
+# POST /auth/register reachable, which is the bypass that allowed bots to
+# create non-OAuth users in spite of AUTH_TYPE=GOOGLE.
+if config.AUTH_TYPE != "GOOGLE":
+ app.include_router(
+ fastapi_users.get_auth_router(auth_backend),
+ prefix="/auth/jwt",
+ tags=["auth"],
+ dependencies=[
+ Depends(rate_limit_login),
+ Depends(
+ registration_allowed
+ ), # honour REGISTRATION_ENABLED kill switch on login too
+ ],
+ )
+ app.include_router(
+ fastapi_users.get_register_router(UserRead, UserCreate),
+ prefix="/auth",
+ tags=["auth"],
+ dependencies=[
+ Depends(rate_limit_register),
+ Depends(registration_allowed),
+ ],
+ )
+ app.include_router(
+ fastapi_users.get_reset_password_router(),
+ prefix="/auth",
+ tags=["auth"],
+ dependencies=[Depends(rate_limit_password_reset)],
+ )
+ app.include_router(
+ fastapi_users.get_verify_router(UserRead),
+ prefix="/auth",
+ tags=["auth"],
+ )
+
+# /users/me (read/update profile) is needed in every auth mode, so it stays
+# mounted unconditionally.
app.include_router(
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
@@ -822,16 +852,25 @@ if config.AUTH_TYPE == "GOOGLE":
),
prefix="/auth/google",
tags=["auth"],
- dependencies=[
- Depends(registration_allowed)
- ], # blocks OAuth registration when disabled
+ # REGISTRATION_ENABLED is a master auth kill switch: when set to FALSE
+ # it blocks BOTH new OAuth signups AND login of existing OAuth users
+ # (the fastapi-users OAuth router shares one callback for create+login,
+ # so this dependency closes both paths together).
+ dependencies=[Depends(registration_allowed)],
)
# Add a redirect-based authorize endpoint for Firefox/Safari compatibility
# This endpoint performs a server-side redirect instead of returning JSON
# which fixes cross-site cookie issues where browsers don't send cookies
- # set via cross-origin fetch requests on subsequent redirects
- @app.get("/auth/google/authorize-redirect", tags=["auth"])
+ # set via cross-origin fetch requests on subsequent redirects.
+ # The registration_allowed dependency mirrors the OAuth router above so
+ # the kill switch fails fast here instead of bouncing users to Google
+ # only to 403 on the callback.
+ @app.get(
+ "/auth/google/authorize-redirect",
+ tags=["auth"],
+ dependencies=[Depends(registration_allowed)],
+ )
async def google_authorize_redirect(
request: Request,
):
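Taken together, the patch leans on two levers: routers that are simply not mounted in Google-only mode, and a request-time 403 kill switch on the surfaces that are. A minimal standalone sketch of that combination, with a placeholder app and router rather than the SurfSense wiring:

    import os

    from fastapi import APIRouter, Depends, FastAPI, HTTPException, status

    AUTH_TYPE = os.getenv("AUTH_TYPE", "GOOGLE")
    REGISTRATION_ENABLED = os.getenv("REGISTRATION_ENABLED", "FALSE").upper() == "TRUE"


    def registration_allowed() -> None:
        # Master kill switch: 403 every gated auth surface when disabled.
        if not REGISTRATION_ENABLED:
            raise HTTPException(status.HTTP_403_FORBIDDEN, detail="Registration is disabled")


    password_router = APIRouter()


    @password_router.post("/auth/register")
    async def register() -> dict:
        return {"ok": True}


    app = FastAPI()
    if AUTH_TYPE != "GOOGLE":
        # In OAuth-only deployments this router is never mounted, so
        # POST /auth/register 404s instead of sitting behind a dependency.
        app.include_router(password_router, dependencies=[Depends(registration_allowed)])

In the real app the same `registration_allowed` dependency is also attached to the Google OAuth router and the authorize-redirect helper, so flipping REGISTRATION_ENABLED to FALSE freezes every session-minting path at once.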