diff --git a/.github/workflows/desktop-release.yml b/.github/workflows/desktop-release.yml index e356bd3e5..ad1c128bc 100644 --- a/.github/workflows/desktop-release.yml +++ b/.github/workflows/desktop-release.yml @@ -144,6 +144,11 @@ jobs: APPLE_ID: ${{ secrets.APPLE_ID }} APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} + # TEMP DEBUG — remove once the codesign hang on macos-latest is diagnosed. + # Surfaces the exact codesign / notarize commands electron-builder spawns, + # so we can see which subprocess hangs. + DEBUG: electron-builder,electron-osx-sign*,@electron/notarize* + ELECTRON_BUILDER_ALLOW_UNRESOLVED_DEPENDENCIES: "true" # Service principal credentials for Azure.Identity EnvironmentCredential used by the # TrustedSigning PowerShell module. Only populated when signing is enabled. # electron-builder 26 does not yet support OIDC federated tokens for Azure signing, diff --git a/.github/workflows/notary-status.yml b/.github/workflows/notary-status.yml new file mode 100644 index 000000000..5c7c42038 --- /dev/null +++ b/.github/workflows/notary-status.yml @@ -0,0 +1,60 @@ +name: Notary status check + +# One-off diagnostic workflow. Queries Apple's notary service to see if your +# submissions are queued, in progress, accepted, or rejected. Useful when a +# notarization seems "hung" — most often the queue itself, especially on a +# brand-new Apple Developer account. +# +# Run via: Actions tab -> "Notary status check" -> Run workflow. +# Inputs are optional; if you provide a submission ID, it also fetches that +# submission's full Apple log. +# +# Safe to delete after diagnosis. + +on: + workflow_dispatch: + inputs: + submission_id: + description: 'Optional: submission UUID to fetch full Apple log for' + required: false + default: '' + +jobs: + status: + runs-on: macos-latest + steps: + - name: List recent notarization submissions + env: + APPLE_ID: ${{ secrets.APPLE_ID }} + APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} + APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} + run: | + set -euo pipefail + echo "::group::Submission history (most recent first)" + xcrun notarytool history \ + --apple-id "$APPLE_ID" \ + --password "$APPLE_APP_SPECIFIC_PASSWORD" \ + --team-id "$APPLE_TEAM_ID" + echo "::endgroup::" + + - name: Inspect specific submission (if id provided) + if: ${{ inputs.submission_id != '' }} + env: + APPLE_ID: ${{ secrets.APPLE_ID }} + APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} + APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} + SUBMISSION_ID: ${{ inputs.submission_id }} + run: | + set -euo pipefail + echo "::group::Submission info" + xcrun notarytool info "$SUBMISSION_ID" \ + --apple-id "$APPLE_ID" \ + --password "$APPLE_APP_SPECIFIC_PASSWORD" \ + --team-id "$APPLE_TEAM_ID" + echo "::endgroup::" + echo "::group::Apple's processing log for this submission" + xcrun notarytool log "$SUBMISSION_ID" \ + --apple-id "$APPLE_ID" \ + --password "$APPLE_APP_SPECIFIC_PASSWORD" \ + --team-id "$APPLE_TEAM_ID" || true + echo "::endgroup::" diff --git a/.vscode/launch.json b/.vscode/launch.json index 029e7c647..ad8f8f2a7 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -26,7 +26,16 @@ "pythonArgs": [ "run", "python" - ] + ], + // Mute LangGraph/Pydantic checkpoint serializer warnings + // (UserWarnings emitted from pydantic/main.py when the + // runtime snapshots a SurfSenseContextSchema into a field + // typed `None`) so the debugger's "Raised Exceptions" + // 
breakpoint doesn't pause on a known-harmless event. + // Production logs are unaffected. + "env": { + "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main" + } }, { "name": "Backend: FastAPI (No Reload)", @@ -40,7 +49,10 @@ "pythonArgs": [ "run", "python" - ] + ], + "env": { + "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main" + } }, { "name": "Backend: FastAPI (main.py)", @@ -54,7 +66,10 @@ "pythonArgs": [ "run", "python" - ] + ], + "env": { + "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main" + } }, { "name": "Frontend: Next.js", @@ -104,7 +119,10 @@ "pythonArgs": [ "run", "python" - ] + ], + "env": { + "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main" + } }, { "name": "Celery: Beat Scheduler", @@ -124,7 +142,10 @@ "pythonArgs": [ "run", "python" - ] + ], + "env": { + "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main" + } } ], "compounds": [ diff --git a/VERSION b/VERSION index 44517d518..236c7ad08 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.19 +0.0.21 diff --git a/docker/.env.example b/docker/.env.example index 95de0cf85..fd56bdccc 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE # STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 # STRIPE_RECONCILIATION_BATCH_SIZE=100 -# Premium token purchases ($1 per 1M tokens for premium-tier models) +# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of +# credit; premium turns debit the actual per-call provider cost +# reported by LiteLLM, so cheap and expensive models bill proportionally) # STRIPE_TOKEN_BUYING_ENABLED=FALSE # STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... -# STRIPE_TOKENS_PER_UNIT=1000000 +# STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000 # ------------------------------------------------------------------------------ # TTS & STT (Text-to-Speech / Speech-to-Text) @@ -305,6 +308,24 @@ STT_SERVICE=local/base # Advanced (optional) # ------------------------------------------------------------------------------ +# New-chat agent feature flags +SURFSENSE_ENABLE_CONTEXT_EDITING=true +SURFSENSE_ENABLE_COMPACTION_V2=true +SURFSENSE_ENABLE_RETRY_AFTER=true +SURFSENSE_ENABLE_MODEL_FALLBACK=false +SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true +SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true +SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true +SURFSENSE_ENABLE_BUSY_MUTEX=true +SURFSENSE_ENABLE_SKILLS=true +SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true +SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true +SURFSENSE_ENABLE_ACTION_LOG=true +SURFSENSE_ENABLE_REVERT_ROUTE=true +SURFSENSE_ENABLE_PERMISSION=true +SURFSENSE_ENABLE_DOOM_LOOP=true +SURFSENSE_ENABLE_STREAM_PARITY_V2=true + # Periodic connector sync interval (default: 5m) # SCHEDULE_CHECKER_INTERVAL=5m @@ -315,9 +336,24 @@ STT_SERVICE=local/base # Pages limit per user for ETL (default: unlimited) # PAGES_LIMIT=500 -# Premium token quota per registered user (default: 5M) -# Only applies to models with billing_tier=premium in global_llm_config.yaml -# PREMIUM_TOKEN_LIMIT=5000000 +# Premium credit quota per registered user, in micro-USD (default: $5). +# Premium turns are debited at the actual per-call provider cost reported +# by LiteLLM. Only applies to models with billing_tier=premium. +# PREMIUM_CREDIT_MICROS_LIMIT=5000000 +# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000 + +# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default). +# QUOTA_MAX_RESERVE_MICROS=1000000 + +# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default). 
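+# Worked example: a batch of 20 images holds 20 x 50_000 = 1_000_000 micros
+# ($1.00) of credit while the calls are in flight.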
+# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000 + +# Per-podcast reservation for the podcast Celery task ($0.20 default). +# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000 + +# Per-video-presentation reservation for the video Celery task ($1.00 default). +# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care. +# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000 # No-login (anonymous) mode — public users can chat without an account # Set TRUE to enable /free pages and anonymous chat API diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 98d396363..ba89059c8 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000 # Set FALSE to disable new checkout session creation temporarily STRIPE_PAGE_BUYING_ENABLED=TRUE -# Premium token purchases via Stripe (for premium-tier model usage) -# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens) +# Premium credit purchases via Stripe (for premium-tier model usage). +# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit +# (default 1_000_000 = $1.00). Premium turns are billed at the actual +# per-call provider cost reported by LiteLLM. STRIPE_TOKEN_BUYING_ENABLED=FALSE STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... -STRIPE_TOKENS_PER_UNIT=1000000 +STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping): +# STRIPE_TOKENS_PER_UNIT=1000000 # Periodic Stripe safety net for purchases left in PENDING (minutes old) STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 @@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300 # (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version) PAGES_LIMIT=500 -# Premium token quota per registered user (default: 3,000,000) -# Applies only to models with billing_tier=premium in global_llm_config.yaml -PREMIUM_TOKEN_LIMIT=3000000 +# Premium credit quota per registered user, in micro-USD +# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the +# actual per-call provider cost reported by LiteLLM, so cheap and expensive +# models bill proportionally. Applies only to models with +# billing_tier=premium in global_llm_config.yaml. +PREMIUM_CREDIT_MICROS_LIMIT=5000000 +# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping): +# PREMIUM_TOKEN_LIMIT=5000000 + +# Safety ceiling on per-call premium reservation, in micro-USD. +# stream_new_chat estimates an upper-bound cost from the model's +# litellm-published per-token rates × the config's quota_reserve_tokens +# and clamps to this value so a misconfigured model can't lock the +# user's whole balance on one call. Default $1.00. +QUOTA_MAX_RESERVE_MICROS=1000000 + +# Per-image reservation (in micro-USD) for the POST /image-generations +# endpoint. Bypassed for free configs. Default $0.05. +QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000 + +# Per-podcast reservation (in micro-USD) used by the podcast Celery task. +# Single envelope covers one transcript-generation LLM call. Default $0.20. +QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000 + +# Per-video-presentation reservation (in micro-USD) used by the video +# presentation Celery task. Covers worst-case fan-out of N slide-scene +# generations + refines. Default $1.00. NOTE: tasks using the override +# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care. 
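+# Worked example (assuming the same reserve-then-settle flow as premium chat
+# turns): one video task holds 1_000_000 micros ($1.00) up front; if the
+# slide generations actually cost $0.37, 370_000 micros are debited and the
+# remaining 630_000 released on settlement.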
+QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000 # No-login (anonymous) mode — allows public users to chat without an account # Set TRUE to enable /free pages and anonymous chat API @@ -297,3 +327,30 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_PLUGIN_LOADER=false # Comma-separated allowlist of plugin entry-point names # SURFSENSE_ALLOWED_PLUGINS=year_substituter + +# ----------------------------------------------------------------------------- +# Compiled-agent cache (Phase 1 + 2 perf optimization, default ON) +# ----------------------------------------------------------------------------- +# When ON, the per-turn LangGraph + middleware compile result (~3-5s of CPU +# on a cold turn) is reused across subsequent turns on the same thread, +# collapsing it to a microsecond hash lookup. All connector tools acquire +# their own short-lived DB session per call (Phase 2 refactor) so a cached +# closure is safe to share across requests. Flip OFF only as a last-resort +# rollback if you suspect cache-related staleness. +# SURFSENSE_ENABLE_AGENT_CACHE=true + +# Cache capacity (max number of compiled-agent entries kept in memory) +# and TTL per entry (seconds). Working set is typically one entry per +# active thread on this replica; tune up for very large deployments. +# SURFSENSE_AGENT_CACHE_MAXSIZE=256 +# SURFSENSE_AGENT_CACHE_TTL_SECONDS=1800 + +# ----------------------------------------------------------------------------- +# Connector discovery TTL cache (Phase 1.4 perf optimization) +# ----------------------------------------------------------------------------- +# Caches the per-search-space "available connectors" + "available document +# types" lookups that ``create_surfsense_deep_agent`` hits on every turn. +# ORM event listeners auto-invalidate on connector / document inserts, +# updates and deletes — the TTL only bounds staleness for bulk-import +# paths that bypass the ORM. Set to 0 to disable the cache. +# SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30 diff --git a/surfsense_backend/Dockerfile b/surfsense_backend/Dockerfile index 1222b36b6..73d5819b9 100644 --- a/surfsense_backend/Dockerfile +++ b/surfsense_backend/Dockerfile @@ -38,16 +38,26 @@ RUN pip install --upgrade certifi pip-system-certs COPY pyproject.toml . COPY uv.lock . -# Install PyTorch based on architecture -RUN if [ "$(uname -m)" = "x86_64" ]; then \ - pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121; \ - else \ - pip install --no-cache-dir torch torchvision torchaudio; \ - fi - -# Install python dependencies +# Install all Python dependencies from uv.lock for deterministic builds. +# +# `uv pip install -e .` re-resolves from pyproject.toml and ignores uv.lock, +# which lets prod silently drift to newer upstream versions on every rebuild +# (e.g. deepagents 0.4.x -> 0.5.x breaking the FilesystemMiddleware imports). +# Exporting the lock to requirements.txt and feeding it to `uv pip install` +# pins every transitive package to the exact version captured in uv.lock. +# +# Note on torch/CUDA: we do NOT install torch from a separate cu* index here. +# PyPI's torch wheels for Linux x86_64 already ship CUDA-enabled and pull +# nvidia-cudnn-cu13, nvidia-nccl-cu13, triton, etc. as install deps (all +# captured in uv.lock). Installing from cu121 first only wasted ~2GB of +# downloads that the lock-based install immediately replaced. 
If a specific +# CUDA version is needed (driver compatibility, etc.), wire it through +# [tool.uv.sources] in pyproject.toml so the lock stays the source of truth. RUN pip install --no-cache-dir uv && \ - uv pip install --system --no-cache-dir -e . + uv export --frozen --no-dev --no-hashes --no-emit-project \ + --format requirements-txt -o /tmp/requirements.txt && \ + uv pip install --system --no-cache-dir -r /tmp/requirements.txt && \ + rm /tmp/requirements.txt # Set SSL environment variables dynamically RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \ @@ -66,13 +76,18 @@ RUN cd /root/.EasyOCR/model && (unzip -o english_g2.zip || true) && (unzip -o cr # Pre-download Docling models RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true -# Install Playwright browsers for web scraping if needed -RUN pip install playwright && \ - playwright install chromium --with-deps +# Install Playwright browsers for web scraping (the playwright package itself +# is already installed via uv.lock above) +RUN playwright install chromium --with-deps # Copy source code COPY . . +# Install the project itself in editable mode. Dependencies were already +# installed deterministically from uv.lock above, so --no-deps prevents any +# re-resolution that could pull newer versions. +RUN uv pip install --system --no-cache-dir --no-deps -e . + # Copy and set permissions for entrypoint script # Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts) COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py new file mode 100644 index 000000000..fba621a0c --- /dev/null +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -0,0 +1,44 @@ +"""138_add_thread_auto_model_pinning_fields + +Revision ID: 138 +Revises: 137 +Create Date: 2026-04-30 + +Add a single thread-level column to persist the Auto (Fastest) model pin: +- pinned_llm_config_id: concrete resolved global LLM config id used for this + thread. NULL means "no pin; Auto will resolve on next turn". + +The column is unindexed: all reads are by new_chat_threads.id (primary key), +so a secondary index would be dead write amplification. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op + +revision: str = "138" +down_revision: str | None = "137" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.execute( + "ALTER TABLE new_chat_threads " + "ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER" + ) + + +def downgrade() -> None: + # Drop any shape the thread row may be carrying. The extra columns and + # indexes only exist on dev DBs that ran an earlier draft of 138; IF EXISTS + # makes each statement a safe no-op on the lean shape. 
+ op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode") + op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id") + op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at") + op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode") + op.execute( + "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id" + ) diff --git a/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py new file mode 100644 index 000000000..83c96a429 --- /dev/null +++ b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py @@ -0,0 +1,160 @@ +"""add user table to zero_publication with column list + +Adds the "user" table to zero_publication with a column-list publication +so that only the 5 fields driving the live usage meters are replicated +through WAL -> zero-cache -> browser IndexedDB: + + id, pages_limit, pages_used, + premium_tokens_limit, premium_tokens_used + +Sensitive columns (hashed_password, email, oauth_account, display_name, +avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT +included in the publication, so they never enter WAL replication. + +Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency +(it is already DEFAULT today since "user" was never in the +TABLES_WITH_FULL_IDENTITY list of migration 117). + +IMPORTANT - before AND after running this migration: + 1. Stop zero-cache (it holds replication locks that will deadlock DDL) + 2. Run: alembic upgrade head + 3. Delete / reset the zero-cache data volume + 4. Restart zero-cache (it will do a fresh initial sync) + +Revision ID: 139 +Revises: 138 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "139" +down_revision: str | None = "138" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +PUBLICATION_NAME = "zero_publication" + +# Document column list as left by migration 117. Must match exactly. +DOCUMENT_COLS = [ + "id", + "title", + "document_type", + "search_space_id", + "folder_id", + "created_by_id", + "status", + "created_at", + "updated_at", +] + +# Five fields needed by the live usage meters (sidebar Tokens/Pages, +# Buy Tokens content). Keep this list narrow on purpose: anything added +# here flows into WAL and IndexedDB for every connected browser. 
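+# (For example, adding "email" to this list would replicate every user's
+# email address into each connected browser's IndexedDB; keeping it out is
+# the whole point of the column-list publication.)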
+USER_COLS = [ + "id", + "pages_limit", + "pages_used", + "premium_tokens_limit", + "premium_tokens_used", +] + + +def _terminate_blocked_pids(conn, table: str) -> None: + """Kill backends whose locks on *table* would block our AccessExclusiveLock.""" + conn.execute( + sa.text( + "SELECT pg_terminate_backend(l.pid) " + "FROM pg_locks l " + "JOIN pg_class c ON c.oid = l.relation " + "WHERE c.relname = :tbl " + " AND l.pid != pg_backend_pid()" + ), + {"tbl": table}, + ) + + +def _has_zero_version(conn, table: str) -> bool: + return ( + conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :tbl AND column_name = '_0_version'" + ), + {"tbl": table}, + ).fetchone() + is not None + ) + + +def _build_publication_ddl( + documents_has_zero_ver: bool, user_has_zero_ver: bool +) -> str: + doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) + user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else []) + doc_col_list = ", ".join(doc_cols) + user_col_list = ", ".join(user_cols) + return ( + f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + f"notifications, " + f"documents ({doc_col_list}), " + f"folders, " + f"search_source_connectors, " + f"new_chat_messages, " + f"chat_comments, " + f"chat_session_state, " + f'"user" ({user_col_list})' + ) + + +def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str: + doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) + doc_col_list = ", ".join(doc_cols) + return ( + f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + f"notifications, " + f"documents ({doc_col_list}), " + f"folders, " + f"search_source_connectors, " + f"new_chat_messages, " + f"chat_comments, " + f"chat_session_state" + ) + + +def upgrade() -> None: + conn = op.get_bind() + # asyncpg requires LOCK TABLE inside a transaction block. Alembic already + # opened one via context.begin_transaction(), but the driver still errors + # unless we use an explicit SAVEPOINT (nested transaction) for this block. + tx = conn.begin_nested() if conn.in_transaction() else conn.begin() + with tx: + conn.execute(sa.text("SET lock_timeout = '10s'")) + + _terminate_blocked_pids(conn, "user") + conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE')) + + # Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of + # migration 117, so this is already DEFAULT. Re-assert anyway so + # the column-list publication stays valid (DEFAULT identity only + # requires the PK to be in the column list). 
+ conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT')) + + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + + documents_has_zero_ver = _has_zero_version(conn, "documents") + user_has_zero_ver = _has_zero_version(conn, "user") + + conn.execute( + sa.text(_build_publication_ddl(documents_has_zero_ver, user_has_zero_ver)) + ) + + +def downgrade() -> None: + conn = op.get_bind() + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + documents_has_zero_ver = _has_zero_version(conn, "documents") + conn.execute(sa.text(_build_publication_ddl_without_user(documents_has_zero_ver))) diff --git a/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py new file mode 100644 index 000000000..64aa699e8 --- /dev/null +++ b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py @@ -0,0 +1,291 @@ +"""rename premium token columns to credit micros and add cost_micros to token_usage + +Migrates the premium quota system from a flat token counter to a USD-cost +based credit system, where 1 credit = 1 micro-USD ($0.000001). + +Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe +price means every existing value is already correct in the new unit, no +data transformation needed): + + user.premium_tokens_limit -> premium_credit_micros_limit + user.premium_tokens_used -> premium_credit_micros_used + user.premium_tokens_reserved -> premium_credit_micros_reserved + + premium_token_purchases.tokens_granted -> credit_micros_granted + +New column for cost auditing per turn: + + token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0) + +The "user" table is in zero_publication's column list (added in 139), so +this migration must drop and recreate the publication with the renamed +column names, otherwise zero-cache will replicate stale column names and +the FE Zero schema will fail to bind. + +IMPORTANT - before AND after running this migration: + 1. Stop zero-cache (it holds replication locks that will deadlock DDL) + 2. Run: alembic upgrade head + 3. Delete / reset the zero-cache data volume + 4. Restart zero-cache (it will do a fresh initial sync) + +Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on +"user". Skipping the data-volume reset will leave IndexedDB clients seeing +column-not-found errors from a stale catalog snapshot. + +Revision ID: 140 +Revises: 139 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "140" +down_revision: str | None = "139" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +PUBLICATION_NAME = "zero_publication" + +# Replicates 139's document column list verbatim — must stay in sync. +DOCUMENT_COLS = [ + "id", + "title", + "document_type", + "search_space_id", + "folder_id", + "created_by_id", + "status", + "created_at", + "updated_at", +] + +# Same five live-meter fields as 139, with the renamed column names. 
+USER_COLS = [ + "id", + "pages_limit", + "pages_used", + "premium_credit_micros_limit", + "premium_credit_micros_used", +] + + +def _terminate_blocked_pids(conn, table: str) -> None: + """Kill backends whose locks on *table* would block our AccessExclusiveLock.""" + conn.execute( + sa.text( + "SELECT pg_terminate_backend(l.pid) " + "FROM pg_locks l " + "JOIN pg_class c ON c.oid = l.relation " + "WHERE c.relname = :tbl " + " AND l.pid != pg_backend_pid()" + ), + {"tbl": table}, + ) + + +def _has_zero_version(conn, table: str) -> bool: + return ( + conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :tbl AND column_name = '_0_version'" + ), + {"tbl": table}, + ).fetchone() + is not None + ) + + +def _column_exists(conn, table: str, column: str) -> bool: + return ( + conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :tbl AND column_name = :col" + ), + {"tbl": table, "col": column}, + ).fetchone() + is not None + ) + + +def _build_publication_ddl( + user_cols: list[str], + *, + documents_has_zero_ver: bool, + user_has_zero_ver: bool, +) -> str: + doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) + user_col_list_with_meta = user_cols + ( + ['"_0_version"'] if user_has_zero_ver else [] + ) + doc_col_list = ", ".join(doc_cols) + user_col_list = ", ".join(user_col_list_with_meta) + return ( + f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + f"notifications, " + f"documents ({doc_col_list}), " + f"folders, " + f"search_source_connectors, " + f"new_chat_messages, " + f"chat_comments, " + f"chat_session_state, " + f'"user" ({user_col_list})' + ) + + +def upgrade() -> None: + conn = op.get_bind() + + # ------------------------------------------------------------------ + # 1. Add cost_micros to token_usage. Idempotent guard so re-runs in + # dev environments are safe. + # ------------------------------------------------------------------ + if not _column_exists(conn, "token_usage", "cost_micros"): + op.add_column( + "token_usage", + sa.Column( + "cost_micros", + sa.BigInteger(), + nullable=False, + server_default="0", + ), + ) + + # ------------------------------------------------------------------ + # 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted. + # ------------------------------------------------------------------ + if _column_exists( + conn, "premium_token_purchases", "tokens_granted" + ) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"): + op.alter_column( + "premium_token_purchases", + "tokens_granted", + new_column_name="credit_micros_granted", + ) + + # ------------------------------------------------------------------ + # 3. Rename user.premium_tokens_* -> premium_credit_micros_*. + # + # We must drop the publication first (it references the old column + # names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE + # in a transaction block; alembic's outer transaction already holds + # one, but a SAVEPOINT keeps the LOCK + DDL atomic. 
+ # ------------------------------------------------------------------ + tx = conn.begin_nested() if conn.in_transaction() else conn.begin() + with tx: + conn.execute(sa.text("SET lock_timeout = '10s'")) + + _terminate_blocked_pids(conn, "user") + conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE')) + + # Re-assert REPLICA IDENTITY DEFAULT for safety; column-list + # publications require at least the PK to be in the column list, + # which is true for both the old and new shape. + conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT')) + + # Drop the publication BEFORE renaming columns, otherwise Postgres + # rejects the rename: "cannot drop column ... referenced by + # publication". + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + + for old, new in ( + ("premium_tokens_limit", "premium_credit_micros_limit"), + ("premium_tokens_used", "premium_credit_micros_used"), + ("premium_tokens_reserved", "premium_credit_micros_reserved"), + ): + if _column_exists(conn, "user", old) and not _column_exists( + conn, "user", new + ): + op.alter_column("user", old, new_column_name=new) + + # Update the server_default on the renamed limit column so newly + # inserted users get $5 of credit (== 5_000_000 micros) by + # default. Existing rows are unaffected. + op.alter_column( + "user", + "premium_credit_micros_limit", + server_default="5000000", + ) + + # Recreate the publication with the new column names. + documents_has_zero_ver = _has_zero_version(conn, "documents") + user_has_zero_ver = _has_zero_version(conn, "user") + conn.execute( + sa.text( + _build_publication_ddl( + USER_COLS, + documents_has_zero_ver=documents_has_zero_ver, + user_has_zero_ver=user_has_zero_ver, + ) + ) + ) + + +def downgrade() -> None: + """Revert the rename and drop ``cost_micros``. + + Mirrors ``upgrade``: drop the publication, rename columns back, drop + the new column, recreate the publication with the old column list. + Same zero-cache stop/reset runbook applies in reverse. 
+ """ + conn = op.get_bind() + + tx = conn.begin_nested() if conn.in_transaction() else conn.begin() + with tx: + conn.execute(sa.text("SET lock_timeout = '10s'")) + + _terminate_blocked_pids(conn, "user") + conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE')) + + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + + for new, old in ( + ("premium_credit_micros_limit", "premium_tokens_limit"), + ("premium_credit_micros_used", "premium_tokens_used"), + ("premium_credit_micros_reserved", "premium_tokens_reserved"), + ): + if _column_exists(conn, "user", new) and not _column_exists( + conn, "user", old + ): + op.alter_column("user", new, new_column_name=old) + + op.alter_column( + "user", + "premium_tokens_limit", + server_default="5000000", + ) + + legacy_user_cols = [ + "id", + "pages_limit", + "pages_used", + "premium_tokens_limit", + "premium_tokens_used", + ] + documents_has_zero_ver = _has_zero_version(conn, "documents") + user_has_zero_ver = _has_zero_version(conn, "user") + conn.execute( + sa.text( + _build_publication_ddl( + legacy_user_cols, + documents_has_zero_ver=documents_has_zero_ver, + user_has_zero_ver=user_has_zero_ver, + ) + ) + ) + + if _column_exists( + conn, "premium_token_purchases", "credit_micros_granted" + ) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"): + op.alter_column( + "premium_token_purchases", + "credit_micros_granted", + new_column_name="tokens_granted", + ) + + if _column_exists(conn, "token_usage", "cost_micros"): + op.drop_column("token_usage", "cost_micros") diff --git a/surfsense_backend/alembic/versions/141_unique_chat_message_turn_role.py b/surfsense_backend/alembic/versions/141_unique_chat_message_turn_role.py new file mode 100644 index 000000000..9a27e7ed0 --- /dev/null +++ b/surfsense_backend/alembic/versions/141_unique_chat_message_turn_role.py @@ -0,0 +1,66 @@ +"""141_unique_chat_message_turn_role + +Revision ID: 141 +Revises: 140 +Create Date: 2026-05-04 + +Add a partial unique index on ``new_chat_messages(thread_id, turn_id, role)`` +where ``turn_id IS NOT NULL``. + +Why +--- +The streaming chat path (`stream_new_chat` / `stream_resume_chat`) is being +moved to write its own ``new_chat_messages`` rows server-side instead of +relying on the frontend's later ``POST /threads/{id}/messages`` call. This +closes the "ghost-thread" abuse vector where authenticated callers got free +LLM completions while ``new_chat_messages`` stayed empty. + +For server-side and legacy frontend writes to coexist we need an idempotency +key. The natural triple is ``(thread_id, turn_id, role)``: the server issues +exactly one ``turn_id`` per turn, and a turn produces at most one user +message and one assistant message. Whichever side wins the race writes the +row; the loser hits ``IntegrityError`` and recovers gracefully. + +Partial — ``WHERE turn_id IS NOT NULL`` — so: + + * Legacy rows that predate the ``turn_id`` column (migration 136) keep + co-existing without de-dup. + * Clone / snapshot inserts in + ``app/services/public_chat_service.py`` that build ``NewChatMessage`` + without ``turn_id`` are unaffected (multiple snapshot copies of the same + user/assistant pair are intentional). + +This index coexists with the existing single-column ``ix_new_chat_messages_turn_id`` +from migration 136 — no collision. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "141" +down_revision: str | None = "140" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +INDEX_NAME = "uq_new_chat_messages_thread_turn_role" +TABLE_NAME = "new_chat_messages" + + +def upgrade() -> None: + op.create_index( + INDEX_NAME, + TABLE_NAME, + ["thread_id", "turn_id", "role"], + unique=True, + postgresql_where=sa.text("turn_id IS NOT NULL"), + ) + + +def downgrade() -> None: + op.drop_index(INDEX_NAME, table_name=TABLE_NAME) diff --git a/surfsense_backend/alembic/versions/142_token_usage_message_id_unique.py b/surfsense_backend/alembic/versions/142_token_usage_message_id_unique.py new file mode 100644 index 000000000..43b30a756 --- /dev/null +++ b/surfsense_backend/alembic/versions/142_token_usage_message_id_unique.py @@ -0,0 +1,134 @@ +"""142_token_usage_message_id_unique + +Revision ID: 142 +Revises: 141 +Create Date: 2026-05-04 + +Add a partial unique index on ``token_usage(message_id)`` where +``message_id IS NOT NULL``. + +Why +--- +Two writers can race on the same assistant turn's ``token_usage`` row: + + * ``finalize_assistant_turn`` (server-side, called from the streaming + finally block in ``stream_new_chat`` / ``stream_resume_chat``) + * ``append_message``'s recovery branch in + ``app/routes/new_chat_routes.py`` (legacy frontend round-trip) + +Both currently use ``SELECT ... THEN INSERT`` in separate sessions, so a +micro-second-aligned race could observe "no row" on each side and double +INSERT, producing duplicate ``token_usage`` rows for the same +``message_id``. + +A partial unique index on ``message_id`` (``WHERE message_id IS NOT NULL``) +turns both writes into ``INSERT ... ON CONFLICT (message_id) DO NOTHING`` +no-ops for the loser, hard-eliminating the race at the DB level. Partial +because non-chat usage rows (indexing, image generation, podcasts) keep +``message_id`` NULL — they're per-event, no de-dup needed. + +Pre-flight +---------- +Today's schema only has a non-unique index on ``message_id`` so a +duplicate population could already exist from any past race. We: + + * Detect duplicate ``message_id`` groups (``HAVING COUNT(*) > 1``). + * If the group count is at or below ``DUPLICATE_ABORT_THRESHOLD`` (50) + we dedupe by deleting all but the smallest ``id`` per group. + * If the count exceeds the threshold we abort with a descriptive + error rather than silently mutate prod data — operator must + investigate before retrying. + +Concurrency +----------- +``CREATE INDEX CONCURRENTLY`` is required on this hot table to avoid +stalling production writes during deploy (a regular ``CREATE INDEX`` +holds an ACCESS EXCLUSIVE lock for the duration of the build, which +would block ``token_usage`` INSERTs for every active streaming chat). +The trade-off is a slower migration (CONCURRENTLY scans the table +twice) and the ``CREATE`` statement cannot run inside alembic's default +transaction wrapper — ``autocommit_block()`` handles that. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "142" +down_revision: str | None = "141" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +INDEX_NAME = "uq_token_usage_message_id" +TABLE_NAME = "token_usage" + +# Refuse to silently mutate prod data if the duplicate population is +# unexpectedly large — operator should investigate the upstream cause +# before retrying. 50 is comfortably above any plausible duplicate +# count from the existing race window (the race is microseconds wide). +DUPLICATE_ABORT_THRESHOLD = 50 + + +def upgrade() -> None: + conn = op.get_bind() + + dup_groups = conn.execute( + sa.text( + "SELECT message_id, COUNT(*) AS n " + "FROM token_usage " + "WHERE message_id IS NOT NULL " + "GROUP BY message_id " + "HAVING COUNT(*) > 1" + ) + ).fetchall() + + if len(dup_groups) > DUPLICATE_ABORT_THRESHOLD: + raise RuntimeError( + f"token_usage has {len(dup_groups)} duplicate message_id groups " + f"(threshold={DUPLICATE_ABORT_THRESHOLD}). " + "Resolve the duplicates manually before re-running this migration." + ) + + if dup_groups: + # Delete all but the smallest-id row per duplicate group. The + # smallest id is by definition the earliest insert, so we keep + # the row most likely to reflect the actual stream's first + # successful write. + conn.execute( + sa.text( + """ + DELETE FROM token_usage + WHERE id IN ( + SELECT id FROM ( + SELECT + id, + row_number() OVER ( + PARTITION BY message_id ORDER BY id ASC + ) AS rn + FROM token_usage + WHERE message_id IS NOT NULL + ) ranked + WHERE rn > 1 + ) + """ + ) + ) + + # CREATE INDEX CONCURRENTLY cannot run inside a transaction. Drop + # alembic's auto-transaction for this op only. 
+ with op.get_context().autocommit_block(): + op.execute( + f"CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS {INDEX_NAME} " + f"ON {TABLE_NAME} (message_id) " + "WHERE message_id IS NOT NULL" + ) + + +def downgrade() -> None: + with op.get_context().autocommit_block(): + op.execute(f"DROP INDEX CONCURRENTLY IF EXISTS {INDEX_NAME}") diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py index d3acba175..7afa30a31 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py @@ -11,7 +11,6 @@ from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.types import Checkpointer -from .middleware import build_main_agent_deepagent_middleware from app.agents.multi_agent_chat.subagents.shared.permissions import ( ToolsPermissions, ) @@ -20,6 +19,8 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags from app.agents.new_chat.filesystem_selection import FilesystemMode from app.db import ChatVisibility +from .middleware import build_main_agent_deepagent_middleware + def build_compiled_agent_graph_sync( *, diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py index cb387278b..d23dc33a9 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py @@ -31,8 +31,8 @@ from .propagation import ( from .resume import ( build_resume_command, fan_out_decisions_to_match, - hitlrequest_action_count, get_first_pending_subagent_interrupt, + hitlrequest_action_count, ) logger = logging.getLogger(__name__) @@ -51,7 +51,9 @@ def build_task_tool_with_parent_config( ) if task_description is None: - description = TASK_TOOL_DESCRIPTION.format(available_agents=subagent_description_str) + description = TASK_TOOL_DESCRIPTION.format( + available_agents=subagent_description_str + ) elif "{available_agents}" in task_description: description = task_description.format(available_agents=subagent_description_str) else: @@ -90,11 +92,11 @@ def build_task_tool_with_parent_config( def task( description: Annotated[ str, - "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", # noqa: E501 + "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", ], subagent_type: Annotated[ str, - "The type of subagent to use. Must be one of the available agent types listed in the tool description.", # noqa: E501 + "The type of subagent to use. 
Must be one of the available agent types listed in the tool description.", ], runtime: ToolRuntime, ) -> str | Command: @@ -119,7 +121,9 @@ def build_task_tool_with_parent_config( if callable(get_state): try: snapshot = get_state(sub_config) - pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot) + pending_id, pending_value = get_first_pending_subagent_interrupt( + snapshot + ) except Exception: # Fail loud if a resume is queued: silent fallback would # replay the original interrupt to the user. @@ -158,11 +162,11 @@ def build_task_tool_with_parent_config( async def atask( description: Annotated[ str, - "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", # noqa: E501 + "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", ], subagent_type: Annotated[ str, - "The type of subagent to use. Must be one of the available agent types listed in the tool description.", # noqa: E501 + "The type of subagent to use. Must be one of the available agent types listed in the tool description.", ], runtime: ToolRuntime, ) -> str | Command: @@ -186,7 +190,9 @@ def build_task_tool_with_parent_config( if callable(aget_state): try: snapshot = await aget_state(sub_config) - pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot) + pending_id, pending_value = get_first_pending_subagent_interrupt( + snapshot + ) except Exception: if has_surfsense_resume(runtime): logger.exception( diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py index 1eae3a519..74e47cfab 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py @@ -23,7 +23,6 @@ from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.types import Checkpointer -from ...context_prune.prune_tool_names import safe_exclude_tools from app.agents.multi_agent_chat.subagents import ( build_subagents, get_subagents_to_exclude, @@ -66,6 +65,7 @@ from app.agents.new_chat.plugin_loader import ( from app.agents.new_chat.tools.registry import BUILTIN_TOOLS from app.db import ChatVisibility +from ...context_prune.prune_tool_names import safe_exclude_tools from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py index 6dd3eb721..6a6fd39b7 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py @@ -14,8 +14,10 @@ from langchain_core.tools import BaseTool from langgraph.types import Checkpointer from sqlalchemy.ext.asyncio import AsyncSession -from ..graph.compile_graph_sync import build_compiled_agent_graph_sync -from ..tools import MAIN_AGENT_SURFSENSE_TOOL_NAMES, MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED +from app.agents.multi_agent_chat.subagents import ( + get_subagents_to_exclude, + main_prompt_registry_subagent_lines, +) from app.agents.multi_agent_chat.subagents.mcp_tools.index import ( 
load_mcp_tools_by_connector, ) @@ -24,17 +26,19 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags from app.agents.new_chat.filesystem_backends import build_backend_resolver from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import AgentConfig -from app.agents.multi_agent_chat.subagents import ( - get_subagents_to_exclude, - main_prompt_registry_subagent_lines, -) -from ..system_prompt import build_main_agent_system_prompt from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool from app.agents.new_chat.tools.registry import build_tools_async from app.db import ChatVisibility from app.services.connector_service import ConnectorService from app.utils.perf import get_perf_logger +from ..graph.compile_graph_sync import build_compiled_agent_graph_sync +from ..system_prompt import build_main_agent_system_prompt +from ..tools import ( + MAIN_AGENT_SURFSENSE_TOOL_NAMES, + MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED, +) + _perf_log = get_perf_logger() diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/__init__.py index 914257521..80e86e5c8 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/__init__.py @@ -2,6 +2,9 @@ from __future__ import annotations -from .index import MAIN_AGENT_SURFSENSE_TOOL_NAMES, MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED +from .index import ( + MAIN_AGENT_SURFSENSE_TOOL_NAMES, + MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED, +) __all__ = ["MAIN_AGENT_SURFSENSE_TOOL_NAMES", "MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py index c2ebc2029..938e73bd4 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py @@ -13,7 +13,9 @@ from .resume import create_generate_resume_tool from .video_presentation import create_generate_video_presentation_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: resolved_dependencies = {**(dependencies or {}), **kwargs} podcast = create_generate_podcast_tool( search_space_id=resolved_dependencies["search_space_id"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py index 4ff02856f..6c65b2cee 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py @@ -10,7 +10,9 @@ from app.db import ChatVisibility from .update_memory import create_update_memory_tool, create_update_team_memory_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: resolved_dependencies = {**(dependencies or {}), **kwargs} if 
resolved_dependencies.get("thread_visibility") == ChatVisibility.SEARCH_SPACE: mem = create_update_team_memory_tool( @@ -18,7 +20,10 @@ def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> db_session=resolved_dependencies["db_session"], llm=resolved_dependencies.get("llm"), ) - return {"allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}], "ask": []} + return { + "allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}], + "ask": [], + } mem = create_update_memory_tool( user_id=resolved_dependencies["user_id"], db_session=resolved_dependencies["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py index 350dab563..3546d4d01 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py @@ -11,14 +11,20 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool from .web_search import create_web_search_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: resolved_dependencies = {**(dependencies or {}), **kwargs} web = create_web_search_tool( search_space_id=resolved_dependencies.get("search_space_id"), available_connectors=resolved_dependencies.get("available_connectors"), ) - scrape = create_scrape_webpage_tool(firecrawl_api_key=resolved_dependencies.get("firecrawl_api_key")) - docs = create_search_surfsense_docs_tool(db_session=resolved_dependencies["db_session"]) + scrape = create_scrape_webpage_tool( + firecrawl_api_key=resolved_dependencies.get("firecrawl_api_key") + ) + docs = create_search_surfsense_docs_tool( + db_session=resolved_dependencies["db_session"] + ) return { "allow": [ {"name": getattr(web, "name", "") or "", "tool": web}, diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py index 9bbfdccb9..08b0e005e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py @@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import ( ) -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: _ = {**(dependencies or {}), **kwargs} return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py index 57d8e277d..2538a494b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py @@ -12,7 +12,9 @@ from .search_events import create_search_calendar_events_tool from .update_event import create_update_calendar_event_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, 
dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: resolved_dependencies = {**(dependencies or {}), **kwargs} session_dependencies = { "db_session": resolved_dependencies["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py index 9bbfdccb9..08b0e005e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py @@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import ( ) -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: _ = {**(dependencies or {}), **kwargs} return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py index 561ea44ab..28c4ee6ee 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py @@ -11,7 +11,9 @@ from .delete_page import create_delete_confluence_page_tool from .update_page import create_update_confluence_page_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: resolved_dependencies = {**(dependencies or {}), **kwargs} session_dependencies = { "db_session": resolved_dependencies["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py index 04db7cda6..c0a3bf3c9 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py @@ -11,7 +11,9 @@ from .read_messages import create_read_discord_messages_tool from .send_message import create_send_discord_message_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py index a25755a8d..5864ae972 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py @@ -10,7 +10,9 @@ from .create_file import create_create_dropbox_file_tool from .trash_file import create_delete_dropbox_file_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: d = {**(dependencies or {}), **kwargs} common 
= { "db_session": d["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py index c355536e8..09082d091 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py @@ -14,7 +14,9 @@ from .trash_email import create_trash_gmail_email_tool from .update_draft import create_update_gmail_draft_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py index 85f414d14..7dbee87a0 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py @@ -10,7 +10,9 @@ from .create_file import create_create_google_drive_file_tool from .trash_file import create_delete_google_drive_file_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py index 9d32c320a..342f120be 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py @@ -11,7 +11,9 @@ from .delete_issue import create_delete_jira_issue_tool from .update_issue import create_update_jira_issue_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py index dc147055e..f1ee49964 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py @@ -11,7 +11,9 @@ from .delete_issue import create_delete_linear_issue_tool from .update_issue import create_update_linear_issue_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py 
b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py index 053a905a3..47b303295 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py @@ -11,7 +11,9 @@ from .list_events import create_list_luma_events_tool from .read_event import create_read_luma_event_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py index 0323781e5..c78f630a1 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py @@ -11,7 +11,9 @@ from .delete_page import create_delete_notion_page_tool from .update_page import create_update_notion_page_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py index 5f40ba704..9a2dadd36 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py @@ -10,7 +10,9 @@ from .create_file import create_create_onedrive_file_tool from .trash_file import create_delete_onedrive_file_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py index 9bbfdccb9..08b0e005e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py @@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import ( ) -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: _ = {**(dependencies or {}), **kwargs} return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py index 4bc481307..cbe76b040 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py +++ 
b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py @@ -11,7 +11,9 @@ from .read_messages import create_read_teams_messages_tool from .send_message import create_send_teams_message_tool -def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: +def load_tools( + *, dependencies: dict[str, Any] | None = None, **kwargs: Any +) -> ToolsPermissions: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py index 68c6ce995..79ab3db10 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py @@ -31,6 +31,7 @@ logger = logging.getLogger(__name__) ## Helper functions for fetching connector metadata maps + async def fetch_mcp_connector_metadata_maps( session: AsyncSession, search_space_id: int, @@ -58,6 +59,7 @@ async def fetch_mcp_connector_metadata_maps( ## Helper functions for partitioning tools by connector agent + def partition_mcp_tools_by_connector( tools: Sequence[BaseTool], connector_id_to_type: dict[int, str], @@ -104,8 +106,10 @@ def partition_mcp_tools_by_connector( return dict(buckets) + ## Helper functions for splitting tools by permissions + def _get_mcp_tool_name(tool: BaseTool) -> str: meta: dict[str, Any] = getattr(tool, "metadata", None) or {} orig = meta.get("mcp_original_tool_name") @@ -139,6 +143,7 @@ def _split_tools_by_permissions( ## Main function to load MCP tools and split them by permissions for each connector agent + async def load_mcp_tools_by_connector( session: AsyncSession, search_space_id: int, @@ -148,9 +153,7 @@ async def load_mcp_tools_by_connector( Pass ``bypass_internal_hitl=True`` so the subagent's ``HumanInTheLoopMiddleware`` is the single HITL gate. 
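    Illustrative return shape (connector keys and tool names are
    hypothetical; each value is the same ``ToolsPermissions`` mapping
    that the per-connector ``load_tools`` factories produce)::

        {
            "slack": {"allow": [list_channels_tool], "ask": [send_message_tool]},
            "notion": {"allow": [read_page_tool], "ask": [delete_page_tool]},
        }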
""" - flat = await load_mcp_tools( - session, search_space_id, bypass_internal_hitl=True - ) + flat = await load_mcp_tools(session, search_space_id, bypass_internal_hitl=True) id_map, name_map = await fetch_mcp_connector_metadata_maps(session, search_space_id) buckets = partition_mcp_tools_by_connector(flat, id_map, name_map) return { diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py index 51906858a..1b7a19ad7 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py @@ -8,6 +8,9 @@ from typing import Any, Protocol from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from app.agents.multi_agent_chat.constants import ( + SUBAGENT_TO_REQUIRED_CONNECTOR_MAP, +) from app.agents.multi_agent_chat.subagents.builtins.deliverables.agent import ( build_subagent as build_deliverables_subagent, ) @@ -62,9 +65,6 @@ from app.agents.multi_agent_chat.subagents.connectors.slack.agent import ( from app.agents.multi_agent_chat.subagents.connectors.teams.agent import ( build_subagent as build_teams_subagent, ) -from app.agents.multi_agent_chat.constants import ( - SUBAGENT_TO_REQUIRED_CONNECTOR_MAP, -) from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( read_md_file, ) @@ -105,6 +105,7 @@ SUBAGENT_BUILDERS_BY_NAME: dict[str, SubagentBuilder] = { "teams": build_teams_subagent, } + def _route_resource_package(builder: SubagentBuilder) -> str: mod = builder.__module__ return mod[: -len(".agent")] if mod.endswith(".agent") else mod.rsplit(".", 1)[0] diff --git a/surfsense_backend/app/agents/new_chat/agent_cache.py b/surfsense_backend/app/agents/new_chat/agent_cache.py new file mode 100644 index 000000000..fa8e6fb72 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/agent_cache.py @@ -0,0 +1,357 @@ +"""TTL-LRU cache for compiled SurfSense deep agents. + +Why this exists +--------------- + +``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat +turn: + +1. Discover connectors & document types from Postgres (~50-200ms) +2. Build the tool list (built-in + MCP) (~200ms-1.7s) +3. Compose the system prompt +4. Construct ~15 middleware instances (CPU) +5. Eagerly compile the general-purpose subagent + (``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously, + which builds a second LangGraph + Pydantic schemas — ~1.5-2s of pure + CPU work) +6. Compile the outer LangGraph + +For a single thread, all six steps produce the SAME object on every turn +unless the user has changed their LLM config, toggled a feature flag, +added a connector, etc. The right answer is to compile ONCE per +"agent shape" and reuse the resulting :class:`CompiledStateGraph` for +every subsequent turn on the same thread. + +Why a per-thread key (not a global pool) +---------------------------------------- + +Most middleware in the SurfSense stack captures per-thread state in +``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``, +``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse +would silently leak state across users and threads. Keying the cache on +``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated +turns on the same thread without changing any middleware's behavior. 
+ +Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema` +(read via ``runtime.context``) so the cache can collapse to a single +``(llm_config_id, search_space_id, ...)`` key shared across threads. Until +then, per-thread keying is the only safe option. + +Cache shape +----------- + +* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s, 30 + minutes — matches a typical chat session). ``maxsize`` (default 256) + caps memory; LRU evicts least-recently-used on overflow. +* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent + cold misses on the same key wait for the first build instead of + building N times. +* Process-local: this is an in-memory cache. Multi-replica deployments + pay the build cost once per replica per key. That's fine; the working + set per replica is small (one entry per active thread on that replica). + +Telemetry +--------- + +Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``: + + * ``hit`` — cache hit, microseconds-fast + * ``miss`` — first build for this key, includes build duration + * ``stale`` — entry was found but expired; rebuilt + * ``evict`` — LRU eviction (size-limited) + * ``size`` — current cache occupancy at lookup time +""" + +from __future__ import annotations + +import asyncio +import hashlib +import logging +import os +import time +from collections import OrderedDict +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any + +from app.utils.perf import get_perf_logger + +logger = logging.getLogger(__name__) +_perf_log = get_perf_logger() + + +# --------------------------------------------------------------------------- +# Public API: signature helpers (cache key components) +# --------------------------------------------------------------------------- + + +def stable_hash(*parts: Any) -> str: + """Compute a deterministic SHA1 of the str repr of ``parts``. + + Used for cache key components that need a fixed-width representation + (system prompt, tool list, etc.). SHA1 is fine here — this is not a + security boundary, just a content fingerprint. + """ + h = hashlib.sha1(usedforsecurity=False) + for p in parts: + h.update(repr(p).encode("utf-8", errors="replace")) + h.update(b"\x1f") # ASCII unit separator between parts + return h.hexdigest() + + +def tools_signature( + tools: list[Any] | tuple[Any, ...], + *, + available_connectors: list[str] | None, + available_document_types: list[str] | None, +) -> str: + """Hash the bound-tool surface for cache-key purposes. + + The signature changes whenever: + + * A tool is added or removed from the bound list (built-in toggles, + MCP tools loaded for the user changes, gating rules flip, etc.). + * The available connectors / document types for the search space + change (new connector added, last connector removed, new document + type indexed). Because :func:`get_connector_gated_tools` derives + ``modified_disabled_tools`` from ``available_connectors``, the + tool surface is technically already covered — but we hash the + connector list separately so an empty-list "no tools changed" + situation still rotates the key when, say, the user re-adds a + connector that gates a tool we were already not exposing. + + Stays stable across: + + * Process restarts (tool names + descriptions are static). + * Different replicas (everyone gets the same hash for the same + inputs). 
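    Illustrative (``t1`` / ``t2`` are hypothetical tool objects exposing
    ``.name`` and ``.description``)::

        sig_a = tools_signature(
            [t1, t2],
            available_connectors=["slack"],
            available_document_types=["pdf"],
        )
        sig_b = tools_signature(
            [t2, t1],
            available_connectors=["slack"],
            available_document_types=["pdf"],
        )
        assert sig_a == sig_b  # descriptors are sorted before hashing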
+ """ + tool_descriptors = sorted( + (getattr(t, "name", repr(t)), getattr(t, "description", "")) for t in tools + ) + connectors = sorted(available_connectors or []) + doc_types = sorted(available_document_types or []) + return stable_hash(tool_descriptors, connectors, doc_types) + + +def flags_signature(flags: Any) -> str: + """Hash the resolved :class:`AgentFeatureFlags` dataclass. + + Frozen dataclasses are deterministically reprable, so a SHA1 of their + repr is a stable fingerprint. Restart safe (flags are read once at + process boot). + """ + return stable_hash(repr(flags)) + + +def system_prompt_hash(system_prompt: str) -> str: + """Hash a system prompt string. Cheap, ~30µs for typical prompts.""" + return hashlib.sha1( + system_prompt.encode("utf-8", errors="replace"), + usedforsecurity=False, + ).hexdigest() + + +# --------------------------------------------------------------------------- +# Cache implementation +# --------------------------------------------------------------------------- + + +@dataclass +class _Entry: + value: Any + created_at: float + last_used_at: float + + +class _AgentCache: + """In-process TTL-LRU cache with per-key in-flight de-duplication. + + NOT THREAD-SAFE in the multithreading sense — designed for a single + asyncio event loop. Uvicorn runs one event loop per worker process, + so this is fine; multi-worker deployments simply each maintain their + own cache. + """ + + def __init__(self, *, maxsize: int, ttl_seconds: float) -> None: + self._maxsize = maxsize + self._ttl = ttl_seconds + self._entries: OrderedDict[str, _Entry] = OrderedDict() + # One lock per key — guards "build" so concurrent cold misses on + # the same key wait for the first build instead of all racing. + self._locks: dict[str, asyncio.Lock] = {} + + def _now(self) -> float: + return time.monotonic() + + def _is_fresh(self, entry: _Entry) -> bool: + return (self._now() - entry.created_at) < self._ttl + + def _evict_if_full(self) -> None: + while len(self._entries) >= self._maxsize: + evicted_key, _ = self._entries.popitem(last=False) + self._locks.pop(evicted_key, None) + _perf_log.info( + "[agent_cache] evict key=%s reason=lru size=%d", + _short(evicted_key), + len(self._entries), + ) + + def _touch(self, key: str, entry: _Entry) -> None: + entry.last_used_at = self._now() + self._entries.move_to_end(key, last=True) + + async def get_or_build( + self, + key: str, + *, + builder: Callable[[], Awaitable[Any]], + ) -> Any: + """Return the cached value for ``key`` or call ``builder()`` to make it. + + ``builder`` MUST be idempotent — concurrent cold misses on the + same key collapse to a single ``builder()`` call (the others + wait on the in-flight lock and observe the populated entry on + wake). + """ + # Fast path: hot hit. + entry = self._entries.get(key) + if entry is not None and self._is_fresh(entry): + self._touch(key, entry) + _perf_log.info( + "[agent_cache] hit key=%s age=%.1fs size=%d", + _short(key), + self._now() - entry.created_at, + len(self._entries), + ) + return entry.value + + # Stale entry — drop it; rebuild below. + if entry is not None and not self._is_fresh(entry): + _perf_log.info( + "[agent_cache] stale key=%s age=%.1fs ttl=%.0fs", + _short(key), + self._now() - entry.created_at, + self._ttl, + ) + self._entries.pop(key, None) + + # Slow path: serialize concurrent misses for the same key. 
+ lock = self._locks.setdefault(key, asyncio.Lock()) + async with lock: + # Double-check after acquiring the lock — another waiter may + # have populated the entry while we slept. + entry = self._entries.get(key) + if entry is not None and self._is_fresh(entry): + self._touch(key, entry) + _perf_log.info( + "[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true", + _short(key), + self._now() - entry.created_at, + len(self._entries), + ) + return entry.value + + t0 = time.perf_counter() + try: + value = await builder() + except BaseException: + # Don't cache failed builds; let the next caller retry. + _perf_log.warning( + "[agent_cache] build_failed key=%s elapsed=%.3fs", + _short(key), + time.perf_counter() - t0, + ) + raise + elapsed = time.perf_counter() - t0 + + # Insert + evict. + self._evict_if_full() + now = self._now() + self._entries[key] = _Entry(value=value, created_at=now, last_used_at=now) + self._entries.move_to_end(key, last=True) + _perf_log.info( + "[agent_cache] miss key=%s build=%.3fs size=%d", + _short(key), + elapsed, + len(self._entries), + ) + return value + + def invalidate(self, key: str) -> bool: + """Drop a single entry; return True if anything was removed.""" + removed = self._entries.pop(key, None) is not None + self._locks.pop(key, None) + if removed: + _perf_log.info( + "[agent_cache] invalidate key=%s size=%d", + _short(key), + len(self._entries), + ) + return removed + + def invalidate_prefix(self, prefix: str) -> int: + """Drop every entry whose key starts with ``prefix``. Returns count.""" + keys = [k for k in self._entries if k.startswith(prefix)] + for k in keys: + self._entries.pop(k, None) + self._locks.pop(k, None) + if keys: + _perf_log.info( + "[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d", + _short(prefix), + len(keys), + len(self._entries), + ) + return len(keys) + + def clear(self) -> None: + n = len(self._entries) + self._entries.clear() + self._locks.clear() + if n: + _perf_log.info("[agent_cache] clear removed=%d", n) + + def stats(self) -> dict[str, Any]: + return { + "size": len(self._entries), + "maxsize": self._maxsize, + "ttl_seconds": self._ttl, + } + + +def _short(key: str, n: int = 16) -> str: + """Truncate keys for log lines so they don't blow up log volume.""" + return key if len(key) <= n else f"{key[:n]}..." + + +# --------------------------------------------------------------------------- +# Module-level singleton +# --------------------------------------------------------------------------- + +_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256")) +_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800")) + +_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL) + + +def get_cache() -> _AgentCache: + """Return the process-wide compiled-agent cache singleton.""" + return _cache + + +def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache: + """Replace the singleton with a fresh cache. 
Tests only.""" + global _cache + _cache = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds) + return _cache + + +__all__ = [ + "flags_signature", + "get_cache", + "reload_for_tests", + "stable_hash", + "system_prompt_hash", + "tools_signature", +] diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index b3fe7850b..1f4024d9d 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -10,7 +10,9 @@ We use ``create_agent`` (from langchain) rather than ``create_deep_agent`` This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable subclass of the default ``FilesystemMiddleware`` — while preserving every other behaviour that ``create_deep_agent`` provides (todo-list, subagents, -summarisation, prompt-caching, etc.). +summarisation, etc.). Prompt caching is configured at LLM-build time via +``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather +than as a middleware. """ import asyncio @@ -33,12 +35,18 @@ from langchain.agents.middleware import ( TodoListMiddleware, ToolCallLimitMiddleware, ) -from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.types import Checkpointer from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.agent_cache import ( + flags_signature, + get_cache, + stable_hash, + system_prompt_hash, + tools_signature, +) from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags from app.agents.new_chat.filesystem_backends import build_backend_resolver @@ -52,6 +60,7 @@ from app.agents.new_chat.middleware import ( DedupHITLToolCallsMiddleware, DoomLoopMiddleware, FileIntentMiddleware, + FlattenSystemMessageMiddleware, KnowledgeBasePersistenceMiddleware, KnowledgePriorityMiddleware, KnowledgeTreeMiddleware, @@ -74,6 +83,7 @@ from app.agents.new_chat.plugin_loader import ( load_allowed_plugin_names_from_env, load_plugin_middlewares, ) +from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching from app.agents.new_chat.subagents import build_specialized_subagents from app.agents.new_chat.system_prompt import ( build_configurable_system_prompt, @@ -94,6 +104,39 @@ from app.utils.perf import get_perf_logger _perf_log = get_perf_logger() + +def _resolve_prompt_model_name( + agent_config: AgentConfig | None, + llm: BaseChatModel, +) -> str | None: + """Resolve the model id to feed to provider-variant detection. + + Preference order (matches the established idiom in + ``llm_router_service.py`` — see ``params.get("base_model") or + params.get("model", "")`` usages there): + + 1. ``agent_config.litellm_params["base_model"]`` — required for Azure + deployments where ``model_name`` is the deployment slug, not the + underlying family. Without this, a deployment named e.g. + ``"prod-chat-001"`` would silently miss every provider regex. + 2. ``agent_config.model_name`` — the user's configured model id. + 3. ``getattr(llm, "model", None)`` — fallback for direct callers that + don't supply an ``AgentConfig`` (currently a defensive path; all + production callers pass ``agent_config``). + + Returns ``None`` when nothing is available; ``compose_system_prompt`` + treats that as the ``"default"`` variant (no provider block emitted). 
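    Illustrative (hypothetical config values)::

        agent_config.model_name = "prod-chat-001"  # Azure deployment slug
        agent_config.litellm_params = {"base_model": "azure/gpt-4o"}
        _resolve_prompt_model_name(agent_config, llm)  # -> "azure/gpt-4o"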
+ """ + if agent_config is not None: + params = agent_config.litellm_params or {} + base_model = params.get("base_model") + if isinstance(base_model, str) and base_model.strip(): + return base_model + if agent_config.model_name: + return agent_config.model_name + return getattr(llm, "model", None) + + # ============================================================================= # Connector Type Mapping # ============================================================================= @@ -279,6 +322,14 @@ async def create_surfsense_deep_agent( ) """ _t_agent_total = time.perf_counter() + + # Layer thread-aware prompt caching onto the LLM. Idempotent with the + # build-time call in ``llm_config.py``; this run merely adds + # ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family + # configs now that ``thread_id`` is known. No-op when ``thread_id`` is + # None or the provider is non-OpenAI-family. + apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id) + filesystem_selection = filesystem_selection or FilesystemSelection() backend_resolver = build_backend_resolver( filesystem_selection, @@ -287,23 +338,39 @@ async def create_surfsense_deep_agent( else None, ) - # Discover available connectors and document types for this search space + # Discover available connectors and document types for this search space. + # + # NOTE: These two calls cannot be parallelized via ``asyncio.gather``. + # ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``); + # SQLAlchemy explicitly forbids concurrent operations on the same session + # ("This session is provisioning a new connection; concurrent operations + # are not permitted on the same session"). The Phase 1.4 in-process TTL + # cache in ``connector_service`` already collapses the warm path to a + # near-zero pair of dict lookups, so sequential awaits cost nothing in + # the common case while remaining correct on cold cache misses. 
available_connectors: list[str] | None = None available_document_types: list[str] | None = None _t0 = time.perf_counter() try: - connector_types = await connector_service.get_available_connectors( - search_space_id - ) - if connector_types: - available_connectors = _map_connectors_to_searchable_types(connector_types) + try: + connector_types_result = await connector_service.get_available_connectors( + search_space_id + ) + if connector_types_result: + available_connectors = _map_connectors_to_searchable_types( + connector_types_result + ) + except Exception as e: + logging.warning("Failed to discover available connectors: %s", e) - available_document_types = await connector_service.get_available_document_types( - search_space_id - ) - - except Exception as e: + try: + available_document_types = ( + await connector_service.get_available_document_types(search_space_id) + ) + except Exception as e: + logging.warning("Failed to discover available document types: %s", e) + except Exception as e: # pragma: no cover - defensive outer guard logging.warning(f"Failed to discover available connectors/document types: {e}") _perf_log.info( "[create_agent] Connector/doc-type discovery in %.3fs", @@ -398,6 +465,7 @@ async def create_surfsense_deep_agent( enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, mcp_connector_tools=_mcp_connector_tools, + model_name=_resolve_prompt_model_name(agent_config, llm), ) else: system_prompt = build_surfsense_system_prompt( @@ -405,6 +473,7 @@ async def create_surfsense_deep_agent( enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, mcp_connector_tools=_mcp_connector_tools, + model_name=_resolve_prompt_model_name(agent_config, llm), ) _perf_log.info( "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0 @@ -424,29 +493,77 @@ async def create_surfsense_deep_agent( # entire middleware build + main-graph compile into a single # ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the # event loop stays responsive. + # + # PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed + # on every per-request value that any middleware in the stack closes + # over in ``__init__`` — drop one and you risk leaking state across + # threads. Hits collapse this whole block to a microsecond lookup; + # misses pay the original CPU cost AND populate the cache. + config_id = agent_config.config_id if agent_config is not None else None + + async def _build_agent() -> Any: + return await asyncio.to_thread( + _build_compiled_agent_blocking, + llm=llm, + tools=tools, + final_system_prompt=final_system_prompt, + backend_resolver=backend_resolver, + filesystem_mode=filesystem_selection.mode, + search_space_id=search_space_id, + user_id=user_id, + thread_id=thread_id, + visibility=visibility, + anon_session_id=anon_session_id, + available_connectors=available_connectors, + available_document_types=available_document_types, + # ``mentioned_document_ids`` is consumed by + # ``KnowledgePriorityMiddleware`` per turn via + # ``runtime.context`` (Phase 1.5). We still pass the + # caller-provided list here for the legacy fallback path + # (cache disabled / context not propagated) — the middleware + # drains its own copy after the first read so a cached graph + # never replays stale mentions. 
+ mentioned_document_ids=mentioned_document_ids, + max_input_tokens=_max_input_tokens, + flags=_flags, + checkpointer=checkpointer, + ) + _t0 = time.perf_counter() - agent = await asyncio.to_thread( - _build_compiled_agent_blocking, - llm=llm, - tools=tools, - final_system_prompt=final_system_prompt, - backend_resolver=backend_resolver, - filesystem_mode=filesystem_selection.mode, - search_space_id=search_space_id, - user_id=user_id, - thread_id=thread_id, - visibility=visibility, - anon_session_id=anon_session_id, - available_connectors=available_connectors, - available_document_types=available_document_types, - mentioned_document_ids=mentioned_document_ids, - max_input_tokens=_max_input_tokens, - flags=_flags, - checkpointer=checkpointer, - ) + if _flags.enable_agent_cache and not _flags.disable_new_agent_stack: + # Cache key components — order matters only for human readability; + # the resulting hash is what's stored. Every component must + # rotate on a real shape change AND stay stable across identical + # invocations. + cache_key = stable_hash( + "v1", # schema version of the key — bump if components change + config_id, + thread_id, + user_id, + search_space_id, + visibility, + filesystem_selection.mode, + anon_session_id, + tools_signature( + tools, + available_connectors=available_connectors, + available_document_types=available_document_types, + ), + flags_signature(_flags), + system_prompt_hash(final_system_prompt), + _max_input_tokens, + # ``mentioned_document_ids`` deliberately omitted — middleware + # reads it from ``runtime.context`` (Phase 1.5). + ) + agent = await get_cache().get_or_build(cache_key, builder=_build_agent) + else: + agent = await _build_agent() _perf_log.info( - "[create_agent] Middleware stack + graph compiled in %.3fs", + "[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)", time.perf_counter() - _t0, + "on" + if _flags.enable_agent_cache and not _flags.disable_new_agent_stack + else "off", ) _perf_log.info( @@ -568,7 +685,6 @@ def _build_compiled_agent_blocking( ), create_surfsense_compaction_middleware(llm, StateBackend), PatchToolCallsMiddleware(), - AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key] @@ -998,6 +1114,14 @@ def _build_compiled_agent_blocking( noop_mw, retry_mw, fallback_mw, + # Coalesce a multi-text-block system message into one block + # immediately before the model call. Sits innermost on the + # system-message-mutation chain so it observes every appender + # (todo / filesystem / skills / subagents …) and prevents + # OpenRouter→Anthropic from redistributing ``cache_control`` + # across N blocks and tripping Anthropic's 4-breakpoint cap. + # See ``middleware/flatten_system.py`` for full rationale. + FlattenSystemMessageMiddleware(), # Tool-call repair must run after model emits but before # permission / dedup / doom-loop interpret the calls. repair_mw, @@ -1010,12 +1134,12 @@ def _build_compiled_agent_blocking( action_log_mw, PatchToolCallsMiddleware(), DedupHITLToolCallsMiddleware(agent_tools=list(tools)), - # Plugin slot — sits just before AnthropicCache so plugin-side - # transforms see the final tool result and run before any - # caching heuristics. Multiple plugins in declared order; loader - # filtered by the admin allowlist already. + # Plugin slot — sits at the tail so plugin-side transforms see the + # final tool result. 
Prompt caching is now applied at LLM build time + # via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no + # caching middleware is needed here. Multiple plugins run in declared + # order; loader filtered by the admin allowlist already. *plugin_middlewares, - AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] deepagent_middleware = [m for m in deepagent_middleware if m is not None] diff --git a/surfsense_backend/app/agents/new_chat/context.py b/surfsense_backend/app/agents/new_chat/context.py index c1fe45aaa..d720b524b 100644 --- a/surfsense_backend/app/agents/new_chat/context.py +++ b/surfsense_backend/app/agents/new_chat/context.py @@ -1,10 +1,25 @@ """ Context schema definitions for SurfSense agents. -This module defines the custom state schema used by the SurfSense deep agent. +This module defines the per-invocation context object passed to the SurfSense +deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6). + +The agent's compiled graph is the same across invocations (and cached by +``agent_cache``), so anything that varies per turn — the user mentions a +specific document, the front-end issues a unique ``request_id``, etc. — +MUST live on this context object instead of being captured into a +middleware ``__init__`` closure. Middlewares read fields back via +``runtime.context.<field>``; tools read the same object via ``runtime.context``. + +This object is read today by ``KnowledgePriorityMiddleware`` (for +``mentioned_document_ids``), and by any future middleware that needs +per-request state without invalidating the compiled-agent cache. """ -from typing import NotRequired, TypedDict +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TypedDict class FileOperationContractState(TypedDict): @@ -15,25 +30,35 @@ class FileOperationContractState(TypedDict): turn_id: str -class SurfSenseContextSchema(TypedDict): +@dataclass +class SurfSenseContextSchema: """ - Custom state schema for the SurfSense deep agent. + Per-invocation context for the SurfSense deep agent. - This extends the default agent state with custom fields. - The default state already includes: - - messages: Conversation history - - todos: Task list from TodoListMiddleware - - files: Virtual filesystem from FilesystemMiddleware + Defaults are chosen so the dataclass can be safely default-constructed + (LangGraph's ``Runtime.context`` itself defaults to ``None`` if no + context is supplied — see ``langgraph.runtime.Runtime``). All fields + are optional; consumers must None-check before reading. - We're adding fields needed for knowledge base search: - - search_space_id: The user's search space ID - - db_session: Database session (injected at runtime) - - connector_service: Connector service instance (injected at runtime) + Phase 1.5 fields: + search_space_id: Search space the request is scoped to. + mentioned_document_ids: KB documents the user @-mentioned this turn. + Read by ``KnowledgePriorityMiddleware`` to seed its priority + list. Stays out of the compiled-agent cache key — that's the + whole point of putting it here. + file_operation_contract: One-shot file operation contract emitted + by ``FileIntentMiddleware`` for the upcoming turn. + turn_id / request_id: Correlation IDs surfaced by the streaming + task; populated for telemetry.
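    Illustrative construction (hypothetical field values; the invocation
    shape follows the ``astream_events(..., context=ctx)`` call described
    in the module docstring)::

        ctx = SurfSenseContextSchema(
            search_space_id=42,
            mentioned_document_ids=[101, 205],
            request_id="req-7f3a",
        )
        # async for event in agent.astream_events(inputs, context=ctx): ...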
+ + Phase 2 will extend with: thread_id, user_id, visibility, + filesystem_mode, anon_session_id, available_connectors, + available_document_types, created_by_id (everything currently captured + by middleware ``__init__`` closures). """ - search_space_id: int - file_operation_contract: NotRequired[FileOperationContractState] - turn_id: NotRequired[str] - request_id: NotRequired[str] - # These are runtime-injected and won't be serialized - # db_session and connector_service are passed when invoking the agent + search_space_id: int | None = None + mentioned_document_ids: list[int] = field(default_factory=list) + file_operation_contract: FileOperationContractState | None = None + turn_id: str | None = None + request_id: str | None = None diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index c6524069a..b3dc0fa82 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -3,8 +3,10 @@ Feature flags for the SurfSense new_chat agent stack. These flags gate the newer agent middleware (some ported from OpenCode, some sourced from ``langchain.agents.middleware`` / ``deepagents``, some -SurfSense-native). They follow a "default-OFF for risky things, -default-ON for safe upgrades, master kill-switch for everything new" model. +SurfSense-native). Most shipped agent-stack upgrades now default ON, so a +Docker image update picks up the new behavior even on older installs whose +environment lacks the newly introduced variables. Risky/experimental +integrations stay default OFF, and the master kill-switch can still disable +everything new. All new middleware checks its flag at agent build time. If the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new @@ -14,16 +16,19 @@ operators a single switch to revert to pre-port behavior.
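The ``from_env`` reader below parses each ``SURFSENSE_ENABLE_*`` variable
through ``_env_bool``. A plausible sketch of that contract (assumption:
the real helper is defined elsewhere in this module and may differ in
detail)::

    import os

    def _env_bool(name: str, default: bool) -> bool:
        raw = os.getenv(name)
        if raw is None or not raw.strip():
            return default
        return raw.strip().lower() in {"1", "true", "yes", "on"}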
Examples -------- -Local development (recommended for trying everything except doom-loop / selector): +Defaults: SURFSENSE_ENABLE_CONTEXT_EDITING=true SURFSENSE_ENABLE_COMPACTION_V2=true SURFSENSE_ENABLE_RETRY_AFTER=true + SURFSENSE_ENABLE_MODEL_FALLBACK=false + SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true + SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true - SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy - SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships - SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false - SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events + SURFSENSE_ENABLE_PERMISSION=true + SURFSENSE_ENABLE_DOOM_LOOP=true + SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call + SURFSENSE_ENABLE_STREAM_PARITY_V2=true Master kill-switch (overrides everything else): @@ -60,32 +65,28 @@ class AgentFeatureFlags: disable_new_agent_stack: bool = False # Agent quality — context budget, retry/limits, name-repair, doom-loop - enable_context_editing: bool = False - enable_compaction_v2: bool = False - enable_retry_after: bool = False + enable_context_editing: bool = True + enable_compaction_v2: bool = True + enable_retry_after: bool = True enable_model_fallback: bool = False - enable_model_call_limit: bool = False - enable_tool_call_limit: bool = False - enable_tool_call_repair: bool = False - enable_doom_loop: bool = ( - False # Default OFF until UI handles permission='doom_loop' - ) + enable_model_call_limit: bool = True + enable_tool_call_limit: bool = True + enable_tool_call_repair: bool = True + enable_doom_loop: bool = True # Safety — permissions, concurrency, tool-set narrowing - enable_permission: bool = False # Default OFF for first deploy - enable_busy_mutex: bool = False + enable_permission: bool = True + enable_busy_mutex: bool = True enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost # Skills + subagents - enable_skills: bool = False - enable_specialized_subagents: bool = False - enable_kb_planner_runnable: bool = False + enable_skills: bool = True + enable_specialized_subagents: bool = True + enable_kb_planner_runnable: bool = True # Snapshot / revert - enable_action_log: bool = False - enable_revert_route: bool = ( - False # Backend ships before UI; route returns 503 until this flips - ) + enable_action_log: bool = True + enable_revert_route: bool = True # Streaming parity v2 — opt in to LangChain's structured # ``AIMessageChunk`` content (typed reasoning blocks, tool-input @@ -94,7 +95,7 @@ class AgentFeatureFlags: # text path and the synthetic ``call_`` tool-call id (no # ``langchainToolCallId`` propagation). Schema migrations 135/136 # ship unconditionally because they're forward-compatible. - enable_stream_parity_v2: bool = False + enable_stream_parity_v2: bool = True # Plugins enable_plugin_loader: bool = False @@ -102,6 +103,41 @@ class AgentFeatureFlags: # Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT) enable_otel: bool = False + # Performance — compiled-agent cache (Phase 1 + Phase 2). + # When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled + # graph if the cache key matches (LLM config + thread + tool surface + + # flags + system prompt + filesystem mode). Cuts per-turn agent-build + # wall clock from ~4-5s to <50µs on cache hits. 
+ # + # SAFETY (Phase 2 unblocked this default-on): + # All connector mutation tools (``tools/notion``, ``tools/gmail``, + # ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``, + # ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``, + # ``tools/teams``, ``tools/luma``, ``connected_accounts``, + # ``update_memory``, ``search_surfsense_docs``) now acquire fresh + # short-lived ``AsyncSession`` instances per call via + # :data:`async_session_maker`. The factory still accepts ``db_session`` + # for registry compatibility but ``del``'s it immediately — see any + # of those files' factory docstrings for the rationale. The ``llm`` + # closure is per-(provider, model, config_id) which is already in + # the cache key, so the LLM is safe to share across cached hits of + # the same key. The KB priority middleware reads + # ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5), + # not its constructor closure, so the same compiled agent serves + # turns with different mention lists correctly. + # + # Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the + # environment if a regression surfaces. The path is exercised by + # the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite. + enable_agent_cache: bool = True + # Phase 1 (deferred — measure first): pre-build & share the + # general-purpose subagent ``CompiledSubAgent`` across cold-cache + # misses. Only helps when the outer cache MISSES (cache hits already + # reuse the entire SubAgentMiddleware-compiled graph). Off by default + # until we have data showing cold misses are frequent enough to + # justify the extra global state. + enable_agent_cache_share_gp_subagent: bool = False + @classmethod def from_env(cls) -> AgentFeatureFlags: """Read flags from environment. @@ -115,48 +151,76 @@ class AgentFeatureFlags: "SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent " "middleware is forced OFF for this build." 
) - return cls(disable_new_agent_stack=True) + return cls( + disable_new_agent_stack=True, + enable_context_editing=False, + enable_compaction_v2=False, + enable_retry_after=False, + enable_model_fallback=False, + enable_model_call_limit=False, + enable_tool_call_limit=False, + enable_tool_call_repair=False, + enable_doom_loop=False, + enable_permission=False, + enable_busy_mutex=False, + enable_llm_tool_selector=False, + enable_skills=False, + enable_specialized_subagents=False, + enable_kb_planner_runnable=False, + enable_action_log=False, + enable_revert_route=False, + enable_stream_parity_v2=False, + enable_plugin_loader=False, + enable_otel=False, + enable_agent_cache=False, + enable_agent_cache_share_gp_subagent=False, + ) return cls( disable_new_agent_stack=False, # Agent quality - enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False), - enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False), - enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False), + enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True), + enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True), + enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True), enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False), enable_model_call_limit=_env_bool( - "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False + "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True ), - enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False), + enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True), enable_tool_call_repair=_env_bool( - "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False + "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True ), - enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False), + enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True), # Safety - enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False), - enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False), + enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True), + enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True), enable_llm_tool_selector=_env_bool( "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False ), # Skills + subagents - enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False), + enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True), enable_specialized_subagents=_env_bool( - "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False + "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True ), enable_kb_planner_runnable=_env_bool( - "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False + "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True ), # Snapshot / revert - enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False), - enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False), + enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True), + enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True), # Streaming parity v2 enable_stream_parity_v2=_env_bool( - "SURFSENSE_ENABLE_STREAM_PARITY_V2", False + "SURFSENSE_ENABLE_STREAM_PARITY_V2", True ), # Plugins enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), # Observability enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False), + # Performance + enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True), + enable_agent_cache_share_gp_subagent=_env_bool( + "SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False + ), ) def any_new_middleware_enabled(self) -> bool: diff --git 
a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/new_chat/llm_config.py index 58d8f84d0..bc37bf1c4 100644 --- a/surfsense_backend/app/agents/new_chat/llm_config.py +++ b/surfsense_backend/app/agents/new_chat/llm_config.py @@ -27,6 +27,7 @@ from litellm import get_model_info from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, @@ -89,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM): yield chunk -# Provider mapping for LiteLLM model string construction -PROVIDER_MAP = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "XAI": "xai", - "BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "COMETAPI": "cometapi", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "CUSTOM": "custom", -} +# Provider mapping for LiteLLM model string construction. +# +# Single source of truth lives in +# :mod:`app.services.provider_capabilities` so the YAML loader (which +# runs during ``app.config`` class-body init) can resolve provider +# prefixes without dragging the agent / tools tree into module load +# order. Re-exported here under the historical ``PROVIDER_MAP`` name +# so existing callers (``llm_router_service``, ``image_gen_router_service``, +# tests) keep working unchanged. +from app.services.provider_capabilities import ( # noqa: E402 + _PROVIDER_PREFIX_MAP as PROVIDER_MAP, +) def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: @@ -177,6 +155,17 @@ class AgentConfig: anonymous_enabled: bool = False quota_reserve_tokens: int | None = None + # Capability flag: best-effort True for the chat selector / catalog. + # Resolved via :func:`provider_capabilities.derive_supports_image_input` + # which prefers OpenRouter's ``architecture.input_modalities`` and + # otherwise consults LiteLLM's authoritative model map. Default True + # is the conservative-allow stance — the streaming-task safety net + # (``is_known_text_only_chat_model``) is the *only* place a False + # actually blocks a request. Setting this to False here without an + # authoritative source would silently hide vision-capable models + # (the regression we're fixing). + supports_image_input: bool = True + @classmethod def from_auto_mode(cls) -> "AgentConfig": """ @@ -202,6 +191,12 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, + # Auto routes across the configured pool, which usually + # contains at least one vision-capable deployment; the router + # will surface a 404 from a non-vision deployment as a normal + # ``allowed_fails`` event and fail over rather than blocking + # the request outright. 
+ supports_image_input=True, ) @classmethod @@ -215,10 +210,24 @@ class AgentConfig: Returns: AgentConfig instance """ - return cls( - provider=config.provider.value + # Lazy import to avoid pulling provider_capabilities (and its + # transitive litellm import) into module-init order. + from app.services.provider_capabilities import derive_supports_image_input + + provider_value = ( + config.provider.value if hasattr(config.provider, "value") - else str(config.provider), + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + + return cls( + provider=provider_value, model_name=config.model_name, api_key=config.api_key, api_base=config.api_base, @@ -234,6 +243,16 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, + # BYOK rows have no operator-curated capability flag, so we + # ask LiteLLM (default-allow on unknown). The streaming + # safety net still blocks if the model is *explicitly* + # marked text-only. + supports_image_input=derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ), ) @classmethod @@ -252,15 +271,46 @@ class AgentConfig: Returns: AgentConfig instance """ + # Lazy import to avoid pulling provider_capabilities (and its + # transitive litellm import) into module-init order. + from app.services.provider_capabilities import derive_supports_image_input + # Get system instructions from YAML, default to empty string system_instructions = yaml_config.get("system_instructions", "") + provider = yaml_config.get("provider", "").upper() + model_name = yaml_config.get("model_name", "") + custom_provider = yaml_config.get("custom_provider") + litellm_params = yaml_config.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + + # Explicit YAML override wins; otherwise derive from LiteLLM / + # OpenRouter modalities. The YAML loader already populates this + # field, but this method is also called from + # ``load_global_llm_config_by_id``'s file fallback (hot reload), + # so we re-derive here for safety. PyYAML already parses bare + # ``true`` / ``false`` into real booleans; the ``bool()`` call just + # normalizes any other value the loader may surface (note that a + # quoted ``"false"`` string would coerce truthy, so keep the flag + # an unquoted YAML boolean).
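    # Illustrative global_llm_config.yaml entries this resolution handles
    # (hypothetical models):
    #
    #   - model_name: gpt-4o                 # no flag -> derived from LiteLLM
    #   - model_name: prod-chat-001          # Azure deployment slug
    #     litellm_params:
    #       base_model: azure/gpt-4o         # derivation keys off base_model
    #   - model_name: some-text-only-model
    #     supports_image_input: false        # explicit override wins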
+ if "supports_image_input" in yaml_config: + supports_image_input = bool(yaml_config.get("supports_image_input")) + else: + supports_image_input = derive_supports_image_input( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ) + return cls( - provider=yaml_config.get("provider", "").upper(), - model_name=yaml_config.get("model_name", ""), + provider=provider, + model_name=model_name, api_key=yaml_config.get("api_key", ""), api_base=yaml_config.get("api_base"), - custom_provider=yaml_config.get("custom_provider"), + custom_provider=custom_provider, litellm_params=yaml_config.get("litellm_params"), # Prompt configuration from YAML (with defaults for backwards compatibility) system_instructions=system_instructions if system_instructions else None, @@ -275,6 +325,7 @@ class AgentConfig: is_premium=yaml_config.get("billing_tier", "free") == "premium", anonymous_enabled=yaml_config.get("anonymous_enabled", False), quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"), + supports_image_input=supports_image_input, ) @@ -494,6 +545,11 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: llm = SanitizedChatLiteLLM(**litellm_kwargs) _attach_model_profile(llm, model_string) + # Configure LiteLLM-native prompt caching (cache_control_injection_points + # for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.). + # ``agent_config=None`` here — the YAML path doesn't have provider intent + # in a structured form, so we set only the universal injection points. + apply_litellm_prompt_caching(llm) return llm @@ -518,7 +574,16 @@ def create_chat_litellm_from_agent_config( print("Error: Auto mode requested but LLM Router not initialized") return None try: - return get_auto_mode_llm() + router_llm = get_auto_mode_llm() + if router_llm is not None: + # Universal cache_control_injection_points only — auto-mode + # fans out across providers, so OpenAI-only kwargs (e.g. + # ``prompt_cache_key``) are left off here. ``drop_params`` + # would strip them at the provider boundary anyway, but + # there's no point setting them when we don't know the + # destination. + apply_litellm_prompt_caching(router_llm, agent_config=agent_config) + return router_llm except Exception as e: print(f"Error creating ChatLiteLLMRouter: {e}") return None @@ -549,4 +614,9 @@ def create_chat_litellm_from_agent_config( llm = SanitizedChatLiteLLM(**litellm_kwargs) _attach_model_profile(llm, model_string) + # Build-time prompt caching: sets ``cache_control_injection_points`` for + # all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``. + # Per-thread ``prompt_cache_key`` is layered on later in + # ``create_surfsense_deep_agent`` once ``thread_id`` is known. 
+ apply_litellm_prompt_caching(llm, agent_config=agent_config) return llm diff --git a/surfsense_backend/app/agents/new_chat/middleware/__init__.py b/surfsense_backend/app/agents/new_chat/middleware/__init__.py index 094c102f8..6742bd8de 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/__init__.py +++ b/surfsense_backend/app/agents/new_chat/middleware/__init__.py @@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import ( from app.agents.new_chat.middleware.filesystem import ( SurfSenseFilesystemMiddleware, ) +from app.agents.new_chat.middleware.flatten_system import ( + FlattenSystemMessageMiddleware, +) from app.agents.new_chat.middleware.kb_persistence import ( KnowledgeBasePersistenceMiddleware, commit_staged_filesystem_state, @@ -61,6 +64,7 @@ __all__ = [ "DedupHITLToolCallsMiddleware", "DoomLoopMiddleware", "FileIntentMiddleware", + "FlattenSystemMessageMiddleware", "KnowledgeBasePersistenceMiddleware", "KnowledgeBaseSearchMiddleware", "KnowledgePriorityMiddleware", diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py index 4b5ad546d..e7d9b8f75 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py +++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py @@ -33,6 +33,7 @@ from __future__ import annotations import asyncio import logging +import time import weakref from typing import Any @@ -58,6 +59,11 @@ class _ThreadLockManager: weakref.WeakValueDictionary() ) self._cancel_events: dict[str, asyncio.Event] = {} + self._cancel_requested_at_ms: dict[str, int] = {} + self._cancel_attempt_count: dict[str, int] = {} + # Monotonic per-thread epoch used to prevent stale middleware + # teardown from releasing a newer turn's lock. + self._turn_epoch: dict[str, int] = {} def lock_for(self, thread_id: str) -> asyncio.Lock: lock = self._locks.get(thread_id) @@ -76,14 +82,57 @@ class _ThreadLockManager: def request_cancel(self, thread_id: str) -> bool: event = self._cancel_events.get(thread_id) if event is None: - return False + event = asyncio.Event() + self._cancel_events[thread_id] = event event.set() + now_ms = int(time.time() * 1000) + self._cancel_requested_at_ms[thread_id] = now_ms + self._cancel_attempt_count[thread_id] = ( + self._cancel_attempt_count.get(thread_id, 0) + 1 + ) return True + def is_cancel_requested(self, thread_id: str) -> bool: + event = self._cancel_events.get(thread_id) + return bool(event and event.is_set()) + + def cancel_state(self, thread_id: str) -> tuple[int, int] | None: + if not self.is_cancel_requested(thread_id): + return None + attempts = self._cancel_attempt_count.get(thread_id, 1) + requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0) + return attempts, requested_at_ms + def reset(self, thread_id: str) -> None: event = self._cancel_events.get(thread_id) if event is not None: event.clear() + self._cancel_requested_at_ms.pop(thread_id, None) + self._cancel_attempt_count.pop(thread_id, None) + + def bump_turn_epoch(self, thread_id: str) -> int: + epoch = self._turn_epoch.get(thread_id, 0) + 1 + self._turn_epoch[thread_id] = epoch + return epoch + + def current_turn_epoch(self, thread_id: str) -> int: + return self._turn_epoch.get(thread_id, 0) + + def end_turn(self, thread_id: str) -> None: + """Best-effort terminal cleanup for a thread turn. 
+ + This is intentionally idempotent and safe to call from outer stream + finally-blocks where middleware teardown might be skipped due to abort + or disconnect edge-cases. + """ + # Invalidate any in-flight middleware holder first. This guarantees a + # stale ``aafter_agent`` from an older attempt cannot unlock a newer + # retry that already acquired the lock for the same thread. + self.bump_turn_epoch(thread_id) + lock = self._locks.get(thread_id) + if lock is not None and lock.locked(): + lock.release() + self.reset(thread_id) def release(self, thread_id: str) -> bool: """Force-release the per-thread lock; safety-net for turns that end before ``__end__``. @@ -115,18 +164,28 @@ def get_cancel_event(thread_id: str) -> asyncio.Event: def request_cancel(thread_id: str) -> bool: - """Trip the cancel event for ``thread_id``. Returns True if found.""" + """Trip the cancel event for ``thread_id``. Always returns True.""" return manager.request_cancel(thread_id) +def is_cancel_requested(thread_id: str) -> bool: + """Return whether ``thread_id`` currently has a pending cancel signal.""" + return manager.is_cancel_requested(thread_id) + + +def get_cancel_state(thread_id: str) -> tuple[int, int] | None: + """Return ``(attempt_count, requested_at_ms)`` for pending cancel state.""" + return manager.cancel_state(thread_id) + + def reset_cancel(thread_id: str) -> None: """Reset the cancel event for ``thread_id`` (called between turns).""" manager.reset(thread_id) -def release_lock(thread_id: str) -> bool: - """Force-release the per-thread busy lock; safe to call when not held.""" - return manager.release(thread_id) +def end_turn(thread_id: str) -> None: + """Force end-of-turn cleanup for lock + cancel state.""" + manager.end_turn(thread_id) class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): @@ -151,10 +210,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo super().__init__() self._require_thread_id = require_thread_id self.tools = [] - # Per-call locks owned by this middleware. We track them as - # an instance attribute so ``aafter_agent`` knows which lock - # to release. - self._held_locks: dict[str, asyncio.Lock] = {} + # Per-call lock ownership tracked as (lock, epoch). ``aafter_agent`` + # only releases when its epoch still matches the manager's current + # epoch for the thread, preventing stale unlock races. + self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {} @staticmethod def _thread_id(runtime: Runtime[ContextT]) -> str | None: @@ -205,7 +264,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo if lock.locked(): raise BusyError(request_id=thread_id) await lock.acquire() - self._held_locks[thread_id] = lock + epoch = manager.bump_turn_epoch(thread_id) + self._held_locks[thread_id] = (lock, epoch) # Reset the cancel event so this turn starts fresh reset_cancel(thread_id) return None @@ -219,8 +279,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo thread_id = self._thread_id(runtime) if thread_id is None: return None - lock = self._held_locks.pop(thread_id, None) - if lock is not None and lock.locked(): + held = self._held_locks.pop(thread_id, None) + if held is None: + return None + lock, held_epoch = held + if held_epoch != manager.current_turn_epoch(thread_id): + # Stale teardown from an older attempt (e.g. runtime-recovery path + # already advanced epoch). Do not touch current lock/cancel state. 
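+            # Illustrative sequence: attempt A acquires (epoch bumps to 1),
+            # a disconnect triggers ``end_turn`` (epoch 2, lock released),
+            # attempt B acquires (epoch 3). A's late teardown then compares
+            # its held epoch 1 against the current 3 and bails out here
+            # instead of releasing B's lock.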
+ return None + if lock.locked(): lock.release() # Always clear cancel event between turns so a stale signal # doesn't leak into the next request. @@ -251,9 +318,11 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo __all__ = [ "BusyMutexMiddleware", + "end_turn", "get_cancel_event", + "get_cancel_state", + "is_cancel_requested", "manager", - "release_lock", "request_cancel", "reset_cancel", ] diff --git a/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py b/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py new file mode 100644 index 000000000..29cd57aa0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py @@ -0,0 +1,233 @@ +r"""Coalesce multi-block system messages into a single text block. + +Several middlewares in our deepagent stack each call +``append_to_system_message`` on the way down to the model +(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``, +``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the +request reaches the LLM, the system message has 5+ separate text blocks. + +Anthropic enforces a hard cap of **4 ``cache_control`` blocks per +request**, and we configure 2 injection points +(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting +the prepended ``request.system_message``, this middleware is the +defensive partner: it guarantees that "the system block" is *one* +content block, so LiteLLM's ``AnthropicCacheControlHook`` and any +OpenRouter→Anthropic transformer can never multiply our budget into +several breakpoints by spreading ``cache_control`` across multiple +text blocks of a multi-block system content. + +Without flattening we used to see:: + + OpenrouterException - {"error":{"message":"Provider returned error", + "code":400,"metadata":{"raw":"...A maximum of 4 blocks with + cache_control may be provided. Found 5."}}} + +(Same error class documented in +https://github.com/BerriAI/litellm/issues/15696 and +https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix +in PR #15395 covers the litellm transformer but does not protect us +when the OpenRouter SaaS itself does the redistribution.) + +A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching +the first injection point from ``role: system`` to ``index: 0``) +neutralises the *primary* cause of the same 400 — multiple +``SystemMessage``\ s injected by ``before_agent`` middlewares +(priority/tree/memory/file-intent/anonymous-doc) accumulating across +turns, each tagged with ``cache_control`` by the ``role: system`` +matcher. This middleware remains useful as defence-in-depth against +the multi-block redistribution path. + +Placement: innermost on the system-message-mutation chain, after every +appender (``todo``/``filesystem``/``skills``/``subagents``) and after +summarization, but before ``noop``/``retry``/``fallback`` so each retry +attempt sees a flattened payload. See ``chat_deepagent.py``. + +Idempotent: a string-content system message is left untouched. A list +that contains anything other than plain text blocks (e.g. an image) is +also left untouched — those are rare on system messages and we'd lose +the non-text payload by joining. 
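+
+Example (illustrative content)::
+
+    [{"type": "text", "text": "You are SurfSense..."},
+     {"type": "text", "text": "## Todos"},
+     {"type": "text", "text": "## Filesystem"}]
+
+    -> "You are SurfSense...\n\n## Todos\n\n## Filesystem"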
+""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, + ResponseT, +) +from langchain_core.messages import SystemMessage + +logger = logging.getLogger(__name__) + + +def _flatten_text_blocks(content: list[Any]) -> str | None: + """Return joined text if every block is a plain ``{"type": "text"}``. + + Returns ``None`` when the list contains anything that isn't a text + block we can safely concatenate (image, audio, file, non-standard + blocks, dicts with extra non-cache_control fields). The caller + leaves the original content untouched in that case rather than + silently dropping payload. + + ``cache_control`` on individual blocks is intentionally discarded — + the whole point of flattening is to let LiteLLM's + ``cache_control_injection_points`` re-place a single breakpoint on + the resulting one-block system content. + """ + chunks: list[str] = [] + for block in content: + if isinstance(block, str): + chunks.append(block) + continue + if not isinstance(block, dict): + return None + if block.get("type") != "text": + return None + text = block.get("text") + if not isinstance(text, str): + return None + chunks.append(text) + return "\n\n".join(chunks) + + +def _flattened_request( + request: ModelRequest[ContextT], +) -> ModelRequest[ContextT] | None: + """Return a request with system_message flattened, or ``None`` for no-op.""" + sys_msg = request.system_message + if sys_msg is None: + return None + content = sys_msg.content + if not isinstance(content, list) or len(content) <= 1: + return None + + flattened = _flatten_text_blocks(content) + if flattened is None: + return None + + new_sys = SystemMessage( + content=flattened, + additional_kwargs=dict(sys_msg.additional_kwargs), + response_metadata=dict(sys_msg.response_metadata), + ) + if sys_msg.id is not None: + new_sys.id = sys_msg.id + return request.override(system_message=new_sys) + + +def _diagnostic_summary(request: ModelRequest[Any]) -> str: + """One-line dump of cache_control-relevant request shape. + + Temporary diagnostic to prove where the ``Found N`` cache_control + breakpoints are coming from when Anthropic 400s. Removed once the + root cause is confirmed and a fix is in place. 
+    """
+    sys_msg = request.system_message
+    if sys_msg is None:
+        sys_shape = "none"
+    elif isinstance(sys_msg.content, str):
+        sys_shape = f"str(len={len(sys_msg.content)})"
+    elif isinstance(sys_msg.content, list):
+        sys_shape = f"list(blocks={len(sys_msg.content)})"
+    else:
+        sys_shape = f"other({type(sys_msg.content).__name__})"
+
+    role_hist: list[str] = []
+    multi_block_msgs = 0
+    msgs_with_cc = 0
+    sys_msgs_in_history = 0
+    for m in request.messages:
+        mtype = getattr(m, "type", type(m).__name__)
+        role_hist.append(mtype)
+        if isinstance(m, SystemMessage):
+            sys_msgs_in_history += 1
+        c = getattr(m, "content", None)
+        if isinstance(c, list):
+            multi_block_msgs += 1
+            for blk in c:
+                if isinstance(blk, dict) and "cache_control" in blk:
+                    msgs_with_cc += 1
+                    break
+        if "cache_control" in (getattr(m, "additional_kwargs", {}) or {}):
+            msgs_with_cc += 1
+
+    tools = request.tools or []
+    tools_with_cc = 0
+    for t in tools:
+        if isinstance(t, dict) and (
+            "cache_control" in t
+            or "cache_control" in (t.get("function") or {})
+        ):
+            tools_with_cc += 1
+
+    return (
+        f"sys={sys_shape} msgs={len(request.messages)} "
+        f"sys_msgs_in_history={sys_msgs_in_history} "
+        f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
+        f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
+        f"roles={role_hist[-8:]}"
+    )
+
+
+class FlattenSystemMessageMiddleware(
+    AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
+):
+    """Collapse a multi-text-block system message to a single string.
+
+    Sits innermost on the system-message-mutation chain so it observes
+    every middleware's contribution. Has no other side effect — the
+    body of every block is preserved, just joined with ``"\\n\\n"``.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.tools = []
+
+    def wrap_model_call(  # type: ignore[override]
+        self,
+        request: ModelRequest[ContextT],
+        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
+    ) -> Any:
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
+        flattened = _flattened_request(request)
+        if flattened is not None:
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "[flatten_system] collapsed %d system blocks to one",
+                    len(request.system_message.content),  # type: ignore[arg-type, union-attr]
+                )
+            return handler(flattened)
+        return handler(request)
+
+    async def awrap_model_call(  # type: ignore[override]
+        self,
+        request: ModelRequest[ContextT],
+        handler: Callable[
+            [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
+        ],
+    ) -> Any:
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
+        flattened = _flattened_request(request)
+        if flattened is not None:
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "[flatten_system] collapsed %d system blocks to one",
+                    len(request.system_message.content),  # type: ignore[arg-type, union-attr]
+                )
+            return await handler(flattened)
+        return await handler(request)
+
+
+__all__ = [
+    "FlattenSystemMessageMiddleware",
+    "_flatten_text_blocks",
+    "_flattened_request",
+]
diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py
index 0820e8c3e..ee5c1d182 100644
--- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py
+++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py
@@ -732,7 +732,6 @@ class 
KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] state: AgentState, runtime: Runtime[Any], ) -> dict[str, Any] | None: - del runtime if self.filesystem_mode != FilesystemMode.CLOUD: return None @@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] if anon_doc: return self._anon_priority(state, anon_doc) - return await self._authenticated_priority(state, messages, user_text) + return await self._authenticated_priority(state, messages, user_text, runtime) def _anon_priority( self, @@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] state: AgentState, messages: Sequence[BaseMessage], user_text: str, + runtime: Runtime[Any] | None = None, ) -> dict[str, Any]: t0 = asyncio.get_event_loop().time() ( @@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] user_text=user_text, ) + # Per-turn ``mentioned_document_ids`` flow: + # 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the + # streaming task supplies a fresh :class:`SurfSenseContextSchema` + # on every ``astream_events`` call, so this list is naturally + # scoped to the current turn. Allows cross-turn graph reuse via + # ``agent_cache``. + # 2. Legacy fallback (cache disabled / context not propagated): the + # constructor-injected ``self.mentioned_document_ids`` list. We + # drain it after the first read so a cached graph (no Phase 1.5 + # wiring) doesn't keep replaying the same mentions on every + # turn. + # + # CRITICAL: distinguish "context absent" (legacy caller, no field at + # all) from "context provided but empty" (turn with no mentions). + # ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in + # Python, so a naive ``if ctx_mentions:`` would fall through to the + # legacy closure on every no-mention follow-up turn — replaying the + # mentions baked in by turn 1's cache-miss build. Always drain the + # closure once the runtime path has fired so a cached middleware + # instance can never resurrect stale state. + mention_ids: list[int] = [] + ctx = getattr(runtime, "context", None) if runtime is not None else None + ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None + if ctx_mentions is not None: + # Runtime path is authoritative — even an empty list means + # "this turn has no mentions", NOT "look at the closure". + mention_ids = list(ctx_mentions) + if self.mentioned_document_ids: + self.mentioned_document_ids = [] + elif self.mentioned_document_ids: + mention_ids = list(self.mentioned_document_ids) + self.mentioned_document_ids = [] + mentioned_results: list[dict[str, Any]] = [] - if self.mentioned_document_ids: + if mention_ids: mentioned_results = await fetch_mentioned_documents( - document_ids=self.mentioned_document_ids, + document_ids=mention_ids, search_space_id=self.search_space_id, ) - self.mentioned_document_ids = [] if is_recency: doc_types = _resolve_search_types( diff --git a/surfsense_backend/app/agents/new_chat/prompt_caching.py b/surfsense_backend/app/agents/new_chat/prompt_caching.py new file mode 100644 index 000000000..9fe47cdac --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompt_caching.py @@ -0,0 +1,188 @@ +r"""LiteLLM-native prompt caching configuration for SurfSense agents. + +Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never +activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)`` +gate always failed) with LiteLLM's universal caching mechanism. 
+
+Coverage:
+
+- Marker-based providers (need ``cache_control`` injection, which LiteLLM
+  performs automatically when ``cache_control_injection_points`` is set):
+  ``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
+  ``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
+  (Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
+- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
+  ``deepseek/``, ``xai/`` — these cache automatically for prompts ≥1024
+  tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
+
+We inject **two** breakpoints per request:
+
+- ``index: 0`` — pins the SurfSense system prompt at the head of the
+  request (provider variant, citation rules, tool catalog, KB tree,
+  skills metadata). The langchain agent factory always prepends
+  ``request.system_message`` at index 0 (see ``factory.py``
+  ``_execute_model_async``), so this targets exactly the main system
+  prompt regardless of how many other ``SystemMessage``\ s the
+  ``before_agent`` injectors (priority, tree, memory, file-intent,
+  anonymous-doc) have inserted into ``state["messages"]``. Using
+  ``role: system`` here would apply ``cache_control`` to **every**
+  system-role message and trip Anthropic's hard cap of 4 cache
+  breakpoints per request once the conversation accumulates enough
+  injected system messages — which surfaces as the upstream 400
+  ``A maximum of 4 blocks with cache_control may be provided. Found N``
+  via OpenRouter→Anthropic.
+- ``index: -1`` — pins the latest message so multi-turn savings compound:
+  Anthropic-family providers use longest-matching-prefix lookup, so turn
+  N+1 still reads turn N's cache up to the shared prefix.
+
+For OpenAI-family configs we additionally pass:
+
+- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
+  raises hit rate by sending requests with a shared prefix to the same
+  backend.
+- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
+  5-10 min in-memory cache.
+
+Safety net: ``litellm.drop_params=True`` is set globally in
+``app.services.llm_service`` at module-load time. Any kwarg the destination
+provider doesn't recognise is auto-stripped at the provider transformer
+layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
+``prompt_cache_key`` etc.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from langchain_core.language_models import BaseChatModel
+
+if TYPE_CHECKING:
+    from app.agents.new_chat.llm_config import AgentConfig
+
+logger = logging.getLogger(__name__)
+
+
+# Two-breakpoint policy: head-of-request + latest message. See module
+# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
+# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
+#
+# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
+# ``before_agent`` middlewares (priority, tree, memory, file-intent,
+# anonymous-doc) insert ``SystemMessage`` instances into
+# ``state["messages"]`` that accumulate across turns. With
+# ``role: system`` the LiteLLM hook would tag *every* one of them with
+# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
+# always targets the langchain-prepended ``request.system_message``
+# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
+# block), giving us exactly one stable cache breakpoint.
+_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] 
= ( + {"location": "message", "index": 0}, + {"location": "message", "index": -1}, +) + +# Providers (uppercase ``AgentConfig.provider`` values) that natively expose +# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and +# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers +# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without +# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU, +# MINIMAX), so we can't infer family from the litellm prefix alone. +_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"}) + + +def _is_router_llm(llm: BaseChatModel) -> bool: + """Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import. + + Importing ``app.services.llm_router_service`` at module-load time would + create a cycle via ``llm_config -> prompt_caching -> llm_router_service``. + Class-name comparison is sufficient since the class is defined in a + single place. + """ + return type(llm).__name__ == "ChatLiteLLMRouter" + + +def _is_openai_family_config(agent_config: AgentConfig | None) -> bool: + """Whether the config targets an OpenAI-style prompt-cache surface. + + Strict — only returns True when the user explicitly chose OPENAI, + DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` / + ``YAMLConfig``. Auto-mode and custom providers return False because + we can't statically know the destination. + """ + if agent_config is None or not agent_config.provider: + return False + if agent_config.is_auto_mode: + return False + if agent_config.custom_provider: + return False + return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS + + +def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None: + """Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail. + + Initialises the field to ``{}`` when present-but-None on a Pydantic v2 + model. Returns ``None`` if the LLM type doesn't expose a writable + ``model_kwargs`` attribute (caller should treat as no-op). + """ + model_kwargs = getattr(llm, "model_kwargs", None) + if isinstance(model_kwargs, dict): + return model_kwargs + try: + llm.model_kwargs = {} # type: ignore[attr-defined] + except Exception: + return None + refreshed = getattr(llm, "model_kwargs", None) + return refreshed if isinstance(refreshed, dict) else None + + +def apply_litellm_prompt_caching( + llm: BaseChatModel, + *, + agent_config: AgentConfig | None = None, + thread_id: int | None = None, +) -> None: + """Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter. + + Idempotent — values already present in ``llm.model_kwargs`` (e.g. from + ``agent_config.litellm_params`` overrides) are preserved. Mutates + ``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion`` + via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge + in our custom ``ChatLiteLLMRouter``. + + Args: + llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance. + agent_config: Optional ``AgentConfig`` driving provider-specific + behaviour. When omitted (or auto-mode), only the universal + ``cache_control_injection_points`` are set. + thread_id: Optional thread id used to construct a per-thread + ``prompt_cache_key`` for OpenAI-family providers. Caching still + works without it (server-side automatic), but the key improves + backend routing affinity and therefore hit rate. 
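+
+    Example (sketch; ``cfg`` stands in for an OPENAI-provider
+    ``AgentConfig``, and the model string is illustrative)::
+
+        llm = SanitizedChatLiteLLM(model="openai/gpt-4.1")
+        apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=7)
+        # llm.model_kwargs now holds cache_control_injection_points plus
+        # prompt_cache_key="surfsense-thread-7" and
+        # prompt_cache_retention="24h".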
+ """ + model_kwargs = _get_or_init_model_kwargs(llm) + if model_kwargs is None: + logger.debug( + "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping", + type(llm).__name__, + ) + return + + if "cache_control_injection_points" not in model_kwargs: + model_kwargs["cache_control_injection_points"] = [ + dict(point) for point in _DEFAULT_INJECTION_POINTS + ] + + # OpenAI-family extras only when we statically know the destination is + # OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers + # so we can't safely set OpenAI-only kwargs there (drop_params would + # strip them but it's wasteful to set them in the first place). + if _is_router_llm(llm): + return + if not _is_openai_family_config(agent_config): + return + + if thread_id is not None and "prompt_cache_key" not in model_kwargs: + model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}" + if "prompt_cache_retention" not in model_kwargs: + model_kwargs["prompt_cache_retention"] = "24h" diff --git a/surfsense_backend/app/agents/new_chat/subagents/constants.py b/surfsense_backend/app/agents/new_chat/subagents/constants.py index ef5a33e22..cb1da499b 100644 --- a/surfsense_backend/app/agents/new_chat/subagents/constants.py +++ b/surfsense_backend/app/agents/new_chat/subagents/constants.py @@ -27,14 +27,9 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = ( NON_PROVIDER_STATE_MUTATION_DENY: frozenset[str] = frozenset( { # Exact tool names from shared deny patterns. - *{ - name - for name in WRITE_TOOL_DENY_PATTERNS - if "*" not in name - }, + *{name for name in WRITE_TOOL_DENY_PATTERNS if "*" not in name}, # Additional non-provider state mutation controls. "write_todos", "task", } ) - diff --git a/surfsense_backend/app/agents/new_chat/subagents/providers/linear.py b/surfsense_backend/app/agents/new_chat/subagents/providers/linear.py index 238b13e8e..da332fe28 100644 --- a/surfsense_backend/app/agents/new_chat/subagents/providers/linear.py +++ b/surfsense_backend/app/agents/new_chat/subagents/providers/linear.py @@ -112,10 +112,7 @@ def _permission_middleware(*, selected_tools: Sequence[BaseTool]) -> Any: Rule(permission=name, pattern="*", action="deny") for name in NON_PROVIDER_STATE_MUTATION_DENY ) - rules.extend( - Rule(permission=name, pattern="*", action="ask") - for name in ask_tools - ) + rules.extend(Rule(permission=name, pattern="*", action="ask") for name in ask_tools) return PermissionMiddleware( rulesets=[Ruleset(rules=rules, origin="subagent_linear_specialist")] ) @@ -163,4 +160,3 @@ def build_linear_specialist_subagent( if model is not None: spec["model"] = model return spec # type: ignore[return-value] - diff --git a/surfsense_backend/app/agents/new_chat/subagents/providers/slack.py b/surfsense_backend/app/agents/new_chat/subagents/providers/slack.py index b72edeee8..90ca80152 100644 --- a/surfsense_backend/app/agents/new_chat/subagents/providers/slack.py +++ b/surfsense_backend/app/agents/new_chat/subagents/providers/slack.py @@ -119,10 +119,7 @@ def _permission_middleware(*, selected_tools: Sequence[BaseTool]) -> Any: Rule(permission=name, pattern="*", action="deny") for name in NON_PROVIDER_STATE_MUTATION_DENY ) - rules.extend( - Rule(permission=name, pattern="*", action="ask") - for name in ask_tools - ) + rules.extend(Rule(permission=name, pattern="*", action="ask") for name in ask_tools) return PermissionMiddleware( rulesets=[Ruleset(rules=rules, origin="subagent_slack_specialist")] ) @@ -171,4 +168,3 @@ def build_slack_specialist_subagent( if model is not None: spec["model"] 
= model return spec # type: ignore[return-value] - diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py index 095413bdb..c56db1528 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py @@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector +from app.db import async_session_maker from app.services.confluence import ConfluenceToolMetadataService logger = logging.getLogger(__name__) @@ -18,6 +19,23 @@ def create_create_confluence_page_tool( user_id: str | None = None, connector_id: int | None = None, ): + """ + Factory function to create the create_confluence_page tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_confluence_page tool + """ + del db_session # per-call session — see docstring + @tool async def create_confluence_page( title: str, @@ -42,160 +60,163 @@ def create_create_confluence_page_tool( """ logger.info(f"create_confluence_page called: title='{title}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Confluence tool not properly configured.", } try: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) + async with async_session_maker() as db_session: + metadata_service = ConfluenceToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id + ) - if "error" in context: - return {"status": "error", "message": context["error"]} + if "error" in context: + return {"status": "error", "message": context["error"]} - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected Confluence accounts need re-authentication.", - "connector_type": "confluence", - } + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + return { + "status": "auth_error", + "message": "All connected Confluence accounts need re-authentication.", + "connector_type": "confluence", + } - result = request_approval( - action_type="confluence_page_creation", - tool_name="create_confluence_page", - params={ - "title": title, - "content": content, - "space_id": space_id, - "connector_id": connector_id, - }, - context=context, - ) + result = request_approval( + action_type="confluence_page_creation", + tool_name="create_confluence_page", + params={ + "title": title, + "content": content, + "space_id": space_id, + "connector_id": connector_id, + }, + context=context, + ) - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } - final_title = result.params.get("title", title) - final_content = result.params.get("content", content) or "" - final_space_id = result.params.get("space_id", space_id) - final_connector_id = result.params.get("connector_id", connector_id) + final_title = result.params.get("title", title) + final_content = result.params.get("content", content) or "" + final_space_id = result.params.get("space_id", space_id) + final_connector_id = result.params.get("connector_id", connector_id) - if not final_title or not final_title.strip(): - return {"status": "error", "message": "Page title cannot be empty."} - if not final_space_id: - return {"status": "error", "message": "A space must be selected."} + if not final_title or not final_title.strip(): + return {"status": "error", "message": "Page title cannot be empty."} + if not final_space_id: + return {"status": "error", "message": "A space must be selected."} - from sqlalchemy.future import select + from sqlalchemy.future import select - from app.db import SearchSourceConnector, SearchSourceConnectorType + from app.db import SearchSourceConnector, SearchSourceConnectorType - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Confluence connector found.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } - - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=actual_connector_id - ) - api_result = await client.create_page( - space_id=final_space_id, - title=final_title, - body=final_content, - ) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in str(api_err).lower() - ): - try: - _conn = connector - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - pass - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Confluence account needs additional permissions. 
Please re-authenticate in connector settings.", - } - raise - - page_id = str(api_result.get("id", "")) - page_links = ( - api_result.get("_links", {}) if isinstance(api_result, dict) else {} - ) - page_url = "" - if page_links.get("base") and page_links.get("webui"): - page_url = f"{page_links['base']}{page_links['webui']}" - - kb_message_suffix = "" - try: - from app.services.confluence import ConfluenceKBSyncService - - kb_service = ConfluenceKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - page_id=page_id, - page_title=final_title, - space_id=final_space_id, - body_content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Confluence connector found.", + } + actual_connector_id = connector.id else: - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Confluence connector is invalid.", + } - return { - "status": "success", - "page_id": page_id, - "page_url": page_url, - "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}", - } + try: + client = ConfluenceHistoryConnector( + session=db_session, connector_id=actual_connector_id + ) + api_result = await client.create_page( + space_id=final_space_id, + title=final_title, + body=final_content, + ) + await client.close() + except Exception as api_err: + if ( + "http 403" in str(api_err).lower() + or "status code 403" in str(api_err).lower() + ): + try: + _conn = connector + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + page_id = str(api_result.get("id", "")) + page_links = ( + api_result.get("_links", {}) if isinstance(api_result, dict) else {} + ) + page_url = "" + if page_links.get("base") and page_links.get("webui"): + page_url = f"{page_links['base']}{page_links['webui']}" + + kb_message_suffix = "" + try: + from app.services.confluence import ConfluenceKBSyncService + + kb_service = ConfluenceKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + page_id=page_id, + page_title=final_title, + space_id=final_space_id, + body_content=final_content, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." 
+ ) + else: + kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "page_id": page_id, + "page_url": page_url, + "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py index 7c03c2760..d4cd5032f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py @@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector +from app.db import async_session_maker from app.services.confluence import ConfluenceToolMetadataService logger = logging.getLogger(__name__) @@ -18,6 +19,23 @@ def create_delete_confluence_page_tool( user_id: str | None = None, connector_id: int | None = None, ): + """ + Factory function to create the delete_confluence_page tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured delete_confluence_page tool + """ + del db_session # per-call session — see docstring + @tool async def delete_confluence_page( page_title_or_id: str, @@ -43,137 +61,143 @@ def create_delete_confluence_page_tool( f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Confluence tool not properly configured.", } try: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_deletion_context( - search_space_id, user_id, page_title_or_id - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "confluence", - } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - - page_data = context["page"] - page_id = page_data["page_id"] - page_title = page_data.get("page_title", "") - document_id = page_data["document_id"] - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="confluence_page_deletion", - tool_name="delete_confluence_page", - params={ - "page_id": page_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this page.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = ConfluenceToolMetadataService(db_session) + context = await metadata_service.get_deletion_context( + search_space_id, user_id, page_title_or_id ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=final_connector_id + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "confluence", + } + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} + + page_data = context["page"] + page_id = page_data["page_id"] + page_title = page_data.get("page_title", "") + document_id = page_data["document_id"] + connector_id_from_context = context.get("account", {}).get("id") + + result = request_approval( + action_type="confluence_page_deletion", + tool_name="delete_confluence_page", + params={ + "page_id": page_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - await client.delete_page(final_page_id) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in str(api_err).lower() - ): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", } - raise - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document + final_page_id = result.params.get("page_id", page_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this page.", + } + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Confluence connector is invalid.", + } - message = f"Confluence page '{page_title}' deleted successfully." - if deleted_from_kb: - message += " Also removed from the knowledge base." + try: + client = ConfluenceHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + await client.delete_page(final_page_id) + await client.close() + except Exception as api_err: + if ( + "http 403" in str(api_err).lower() + or "status code 403" in str(api_err).lower() + ): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + } + raise - return { - "status": "success", - "page_id": final_page_id, - "deleted_from_kb": deleted_from_kb, - "message": message, - } + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + + message = f"Confluence page '{page_title}' deleted successfully." + if deleted_from_kb: + message += " Also removed from the knowledge base." 
+ + return { + "status": "success", + "page_id": final_page_id, + "deleted_from_kb": deleted_from_kb, + "message": message, + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py index 791d0d8c5..51c205e00 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py @@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector +from app.db import async_session_maker from app.services.confluence import ConfluenceToolMetadataService logger = logging.getLogger(__name__) @@ -18,6 +19,23 @@ def create_update_confluence_page_tool( user_id: str | None = None, connector_id: int | None = None, ): + """ + Factory function to create the update_confluence_page tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured update_confluence_page tool + """ + del db_session # per-call session — see docstring + @tool async def update_confluence_page( page_title_or_id: str, @@ -45,164 +63,168 @@ def create_update_confluence_page_tool( f"update_confluence_page called: page_title_or_id='{page_title_or_id}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Confluence tool not properly configured.", } try: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, page_title_or_id - ) + async with async_session_maker() as db_session: + metadata_service = ConfluenceToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, page_title_or_id + ) - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "confluence", + } + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} + + page_data = context["page"] + page_id = page_data["page_id"] + current_title = page_data["page_title"] + current_body = page_data.get("body", "") + current_version = page_data.get("version", 1) + document_id = page_data.get("document_id") + connector_id_from_context = context.get("account", {}).get("id") + + result = request_approval( + action_type="confluence_page_update", + tool_name="update_confluence_page", + params={ + "page_id": page_id, + "document_id": document_id, + "new_title": new_title, + "new_content": new_content, + "version": current_version, + "connector_id": connector_id_from_context, + }, + context=context, + ) + + if result.rejected: return { - 
"status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "confluence", + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - page_data = context["page"] - page_id = page_data["page_id"] - current_title = page_data["page_title"] - current_body = page_data.get("body", "") - current_version = page_data.get("version", 1) - document_id = page_data.get("document_id") - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="confluence_page_update", - tool_name="update_confluence_page", - params={ - "page_id": page_id, - "document_id": document_id, - "new_title": new_title, - "new_content": new_content, - "version": current_version, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_title = result.params.get("new_title", new_title) or current_title - final_content = result.params.get("new_content", new_content) - if final_content is None: - final_content = current_body - final_version = result.params.get("version", current_version) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_document_id = result.params.get("document_id", document_id) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this page.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + final_page_id = result.params.get("page_id", page_id) + final_title = result.params.get("new_title", new_title) or current_title + final_content = result.params.get("new_content", new_content) + if final_content is None: + final_content = current_body + final_version = result.params.get("version", current_version) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } + final_document_id = result.params.get("document_id", document_id) - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=final_connector_id - ) - api_result = await client.update_page( - page_id=final_page_id, - title=final_title, - body=final_content, - version_number=final_version + 1, - ) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in str(api_err).lower() - ): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: return { - "status": 
"insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + "status": "error", + "message": "No connector found for this page.", } - raise - page_links = ( - api_result.get("_links", {}) if isinstance(api_result, dict) else {} - ) - page_url = "" - if page_links.get("base") and page_links.get("webui"): - page_url = f"{page_links['base']}{page_links['webui']}" - - kb_message_suffix = "" - if final_document_id: - try: - from app.services.confluence import ConfluenceKBSyncService - - kb_service = ConfluenceKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=final_document_id, - page_id=final_page_id, - user_id=user_id, - search_space_id=search_space_id, + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Confluence connector is invalid.", + } + + try: + client = ConfluenceHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + api_result = await client.update_page( + page_id=final_page_id, + title=final_title, + body=final_content, + version_number=final_version + 1, + ) + await client.close() + except Exception as api_err: + if ( + "http 403" in str(api_err).lower() + or "status code 403" in str(api_err).lower() + ): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + page_links = ( + api_result.get("_links", {}) if isinstance(api_result, dict) else {} + ) + page_url = "" + if page_links.get("base") and page_links.get("webui"): + page_url = f"{page_links['base']}{page_links['webui']}" + + kb_message_suffix = "" + if final_document_id: + try: + from app.services.confluence import ConfluenceKBSyncService + + kb_service = ConfluenceKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=final_document_id, + page_id=final_page_id, + user_id=user_id, + search_space_id=search_space_id, ) - else: + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = ( + " The knowledge base will be updated in the next sync." + ) + except Exception as kb_err: + logger.warning(f"KB sync after update failed: {kb_err}") kb_message_suffix = ( " The knowledge base will be updated in the next sync." ) - except Exception as kb_err: - logger.warning(f"KB sync after update failed: {kb_err}") - kb_message_suffix = ( - " The knowledge base will be updated in the next sync." 
- ) - return { - "status": "success", - "page_id": final_page_id, - "page_url": page_url, - "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}", - } + return { + "status": "success", + "page_id": final_page_id, + "page_url": page_url, + "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py index 5675a42e6..6420a90e6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py +++ b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker from app.services.mcp_oauth.registry import MCP_SERVICES logger = logging.getLogger(__name__) @@ -53,6 +53,23 @@ def create_get_connected_accounts_tool( search_space_id: int, user_id: str, ) -> StructuredTool: + """Factory function to create the get_connected_accounts tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to scope account discovery to. + user_id: User ID to scope account discovery to. + + Returns: + Configured StructuredTool for connected-accounts discovery. + """ + del db_session # per-call session — see docstring async def _run(service: str) -> list[dict[str, Any]]: svc_cfg = MCP_SERVICES.get(service) @@ -68,40 +85,41 @@ def create_get_connected_accounts_tool( except ValueError: return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}] - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == connector_type, + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type == connector_type, + ) ) - ) - connectors = result.scalars().all() + connectors = result.scalars().all() - if not connectors: - return [ - { - "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings." + if not connectors: + return [ + { + "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings." 
+ } + ] + + is_multi = len(connectors) > 1 + + accounts: list[dict[str, Any]] = [] + for conn in connectors: + cfg = conn.config or {} + entry: dict[str, Any] = { + "connector_id": conn.id, + "display_name": _extract_display_name(conn), + "service": service, } - ] + if is_multi: + entry["tool_prefix"] = f"{service}_{conn.id}" + for key in svc_cfg.account_metadata_keys: + if key in cfg: + entry[key] = cfg[key] + accounts.append(entry) - is_multi = len(connectors) > 1 - - accounts: list[dict[str, Any]] = [] - for conn in connectors: - cfg = conn.config or {} - entry: dict[str, Any] = { - "connector_id": conn.id, - "display_name": _extract_display_name(conn), - "service": service, - } - if is_multi: - entry["tool_prefix"] = f"{service}_{conn.id}" - for key in svc_cfg.account_metadata_keys: - if key in cfg: - entry[key] = cfg[key] - accounts.append(entry) - - return accounts + return accounts return StructuredTool( name="get_connected_accounts", diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py index 3cc99ac17..01159a261 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_list_discord_channels_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the list_discord_channels tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured list_discord_channels tool + """ + del db_session # per-call session — see docstring + @tool async def list_discord_channels() -> dict[str, Any]: """List text channels in the connected Discord server. @@ -22,59 +41,60 @@ def create_list_discord_channels_tool( Returns: Dictionary with status and a list of channels (id, name). 
""" - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Discord tool not properly configured.", } try: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} - - guild_id = get_guild_id(connector) - if not guild_id: - return { - "status": "error", - "message": "No guild ID in Discord connector config.", - } - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{DISCORD_API}/guilds/{guild_id}/channels", - headers={"Authorization": f"Bot {token}"}, - timeout=15.0, + async with async_session_maker() as db_session: + connector = await get_discord_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Discord connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } + guild_id = get_guild_id(connector) + if not guild_id: + return { + "status": "error", + "message": "No guild ID in Discord connector config.", + } - # Type 0 = text channel - channels = [ - {"id": ch["id"], "name": ch["name"]} - for ch in resp.json() - if ch.get("type") == 0 - ] - return { - "status": "success", - "guild_id": guild_id, - "channels": channels, - "total": len(channels), - } + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{DISCORD_API}/guilds/{guild_id}/channels", + headers={"Authorization": f"Bot {token}"}, + timeout=15.0, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } + + # Type 0 = text channel + channels = [ + {"id": ch["id"], "name": ch["name"]} + for ch in resp.json() + if ch.get("type") == 0 + ] + return { + "status": "success", + "guild_id": guild_id, + "channels": channels, + "total": len(channels), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py index d8bf989a1..88d6cdd49 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import DISCORD_API, get_bot_token, get_discord_connector logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_read_discord_messages_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_discord_messages tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. 
+ + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_discord_messages tool + """ + del db_session # per-call session — see docstring + @tool async def read_discord_messages( channel_id: str, @@ -30,7 +49,7 @@ def create_read_discord_messages_tool( Dictionary with status and a list of messages including id, author, content, timestamp. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Discord tool not properly configured.", @@ -39,55 +58,56 @@ def create_read_discord_messages_tool( limit = min(limit, 50) try: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{DISCORD_API}/channels/{channel_id}/messages", - headers={"Authorization": f"Bot {token}"}, - params={"limit": limit}, - timeout=15.0, + async with async_session_maker() as db_session: + connector = await get_discord_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Discord connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Bot lacks permission to read this channel.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } + token = get_bot_token(connector) - messages = [ - { - "id": m["id"], - "author": m.get("author", {}).get("username", "Unknown"), - "content": m.get("content", ""), - "timestamp": m.get("timestamp", ""), - } - for m in resp.json() - ] + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{DISCORD_API}/channels/{channel_id}/messages", + headers={"Authorization": f"Bot {token}"}, + params={"limit": limit}, + timeout=15.0, + ) - return { - "status": "success", - "channel_id": channel_id, - "messages": messages, - "total": len(messages), - } + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Bot lacks permission to read this channel.", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } + + messages = [ + { + "id": m["id"], + "author": m.get("author", {}).get("username", "Unknown"), + "content": m.get("content", ""), + "timestamp": m.get("timestamp", ""), + } + for m in resp.json() + ] + + return { + "status": "success", + "channel_id": channel_id, + "messages": messages, + "total": len(messages), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py index 236cd017a..5fe6fde35 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py @@ -6,6 +6,7 @@ from langchain_core.tools import tool from 
sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from ._auth import DISCORD_API, get_bot_token, get_discord_connector @@ -17,6 +18,23 @@ def create_send_discord_message_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the send_discord_message tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured send_discord_message tool + """ + del db_session # per-call session — see docstring + @tool async def send_discord_message( channel_id: str, @@ -34,7 +52,7 @@ def create_send_discord_message_tool( IMPORTANT: - If status is "rejected", the user explicitly declined. Do NOT retry. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Discord tool not properly configured.", @@ -47,64 +65,65 @@ def create_send_discord_message_tool( } try: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} + async with async_session_maker() as db_session: + connector = await get_discord_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Discord connector found."} - result = request_approval( - action_type="discord_send_message", - tool_name="send_discord_message", - params={"channel_id": channel_id, "content": content}, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Message was not sent.", - } - - final_content = result.params.get("content", content) - final_channel = result.params.get("channel_id", channel_id) - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.post( - f"{DISCORD_API}/channels/{final_channel}/messages", - headers={ - "Authorization": f"Bot {token}", - "Content-Type": "application/json", - }, - json={"content": final_content}, - timeout=15.0, + result = request_approval( + action_type="discord_send_message", + tool_name="send_discord_message", + params={"channel_id": channel_id, "content": content}, + context={"connector_id": connector.id}, ) - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Bot lacks permission to send messages in this channel.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. 
Message was not sent.", + } - msg_data = resp.json() - return { - "status": "success", - "message_id": msg_data.get("id"), - "message": f"Message sent to channel {final_channel}.", - } + final_content = result.params.get("content", content) + final_channel = result.params.get("channel_id", channel_id) + + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{DISCORD_API}/channels/{final_channel}/messages", + headers={ + "Authorization": f"Bot {token}", + "Content-Type": "application/json", + }, + json={"content": final_content}, + timeout=15.0, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Bot lacks permission to send messages in this channel.", + } + if resp.status_code not in (200, 201): + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } + + msg_data = resp.json() + return { + "status": "success", + "message_id": msg_data.get("id"), + "message": f"Message sent to channel {final_channel}.", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py index 22d8a8a27..7aae034cc 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py @@ -10,7 +10,7 @@ from sqlalchemy.future import select from app.agents.new_chat.tools.hitl import request_approval from app.connectors.dropbox.client import DropboxClient -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -59,6 +59,23 @@ def create_create_dropbox_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_dropbox_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_dropbox_file tool + """ + del db_session # per-call session — see docstring + @tool async def create_dropbox_file( name: str, @@ -82,184 +99,191 @@ def create_create_dropbox_file_tool( f"create_dropbox_file called: name='{name}', file_type='{file_type}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Dropbox tool not properly configured.", } try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.DROPBOX_CONNECTOR, - ) - ) - connectors = result.scalars().all() - - if not connectors: - return { - "status": "error", - "message": "No Dropbox connector found. 
Please connect Dropbox in your workspace settings.", - } - - accounts = [] - for c in connectors: - cfg = c.config or {} - accounts.append( - { - "id": c.id, - "name": c.name, - "user_email": cfg.get("user_email"), - "auth_expired": cfg.get("auth_expired", False), - } - ) - - if all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected Dropbox accounts need re-authentication.", - "connector_type": "dropbox", - } - - parent_folders: dict[int, list[dict[str, str]]] = {} - for acc in accounts: - cid = acc["id"] - if acc.get("auth_expired"): - parent_folders[cid] = [] - continue - try: - client = DropboxClient(session=db_session, connector_id=cid) - items, err = await client.list_folder("") - if err: - logger.warning( - "Failed to list folders for connector %s: %s", cid, err - ) - parent_folders[cid] = [] - else: - parent_folders[cid] = [ - { - "folder_path": item.get("path_lower", ""), - "name": item["name"], - } - for item in items - if item.get(".tag") == "folder" and item.get("name") - ] - except Exception: - logger.warning( - "Error fetching folders for connector %s", cid, exc_info=True - ) - parent_folders[cid] = [] - - context: dict[str, Any] = { - "accounts": accounts, - "parent_folders": parent_folders, - "supported_types": _SUPPORTED_TYPES, - } - - result = request_approval( - action_type="dropbox_file_creation", - tool_name="create_dropbox_file", - params={ - "name": name, - "file_type": file_type, - "content": content, - "connector_id": None, - "parent_folder_path": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_file_type = result.params.get("file_type", file_type) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_path = result.params.get("parent_folder_path") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - final_name = _ensure_extension(final_name, final_file_type) - - if final_connector_id is not None: + async with async_session_maker() as db_session: result = await db_session.execute( select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type == SearchSourceConnectorType.DROPBOX_CONNECTOR, ) ) - connector = result.scalars().first() - else: - connector = connectors[0] + connectors = result.scalars().all() - if not connector: - return { - "status": "error", - "message": "Selected Dropbox connector is invalid.", + if not connectors: + return { + "status": "error", + "message": "No Dropbox connector found. 
Please connect Dropbox in your workspace settings.", + } + + accounts = [] + for c in connectors: + cfg = c.config or {} + accounts.append( + { + "id": c.id, + "name": c.name, + "user_email": cfg.get("user_email"), + "auth_expired": cfg.get("auth_expired", False), + } + ) + + if all(a.get("auth_expired") for a in accounts): + return { + "status": "auth_error", + "message": "All connected Dropbox accounts need re-authentication.", + "connector_type": "dropbox", + } + + parent_folders: dict[int, list[dict[str, str]]] = {} + for acc in accounts: + cid = acc["id"] + if acc.get("auth_expired"): + parent_folders[cid] = [] + continue + try: + client = DropboxClient(session=db_session, connector_id=cid) + items, err = await client.list_folder("") + if err: + logger.warning( + "Failed to list folders for connector %s: %s", cid, err + ) + parent_folders[cid] = [] + else: + parent_folders[cid] = [ + { + "folder_path": item.get("path_lower", ""), + "name": item["name"], + } + for item in items + if item.get(".tag") == "folder" and item.get("name") + ] + except Exception: + logger.warning( + "Error fetching folders for connector %s", + cid, + exc_info=True, + ) + parent_folders[cid] = [] + + context: dict[str, Any] = { + "accounts": accounts, + "parent_folders": parent_folders, + "supported_types": _SUPPORTED_TYPES, } - client = DropboxClient(session=db_session, connector_id=connector.id) - - parent_path = final_parent_folder_path or "" - file_path = ( - f"{parent_path}/{final_name}" if parent_path else f"/{final_name}" - ) - - if final_file_type == "paper": - created = await client.create_paper_doc(file_path, final_content or "") - file_id = created.get("file_id", "") - web_url = created.get("url", "") - else: - docx_bytes = _markdown_to_docx(final_content or "") - created = await client.upload_file( - file_path, docx_bytes, mode="add", autorename=True + result = request_approval( + action_type="dropbox_file_creation", + tool_name="create_dropbox_file", + params={ + "name": name, + "file_type": file_type, + "content": content, + "connector_id": None, + "parent_folder_path": None, + }, + context=context, ) - file_id = created.get("id", "") - web_url = "" - logger.info(f"Dropbox file created: id={file_id}, name={final_name}") + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } - kb_message_suffix = "" - try: - from app.services.dropbox import DropboxKBSyncService + final_name = result.params.get("name", name) + final_file_type = result.params.get("file_type", file_type) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_path = result.params.get("parent_folder_path") - kb_service = DropboxKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=file_id, - file_name=final_name, - file_path=file_path, - web_url=web_url, - content=final_content, - connector_id=connector.id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." 
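The approval round-trip above is the contract every write-capable tool in this PR shares: `request_approval` surfaces the proposed parameters to the user, a rejected result is terminal, and an approved result may carry user edits in `result.params` that override whatever the model supplied. A minimal sketch of that consumption pattern, with a stand-in `ApprovalResult` (the real return type of `app.agents.new_chat.tools.hitl.request_approval` is not shown in this diff; only the `.rejected` and `.params` attributes used here are assumed):

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ApprovalResult:
    """Stand-in for request_approval's return value (assumed shape)."""

    rejected: bool
    params: dict[str, Any] = field(default_factory=dict)


def apply_approval(result: ApprovalResult, name: str, content: str) -> dict[str, Any]:
    if result.rejected:
        # Rejection is terminal: the tools return "rejected" and instruct
        # the model not to retry or suggest alternatives.
        return {"status": "rejected", "message": "User declined."}
    # The user may have edited any field in the approval UI, so every
    # model-supplied argument is re-read from result.params, falling back
    # to the original value when untouched.
    final_name = result.params.get("name", name)
    final_content = result.params.get("content", content)
    return {"status": "approved", "name": final_name, "content": final_content}


# A run where the user approved but renamed the file:
print(apply_approval(ApprovalResult(False, {"name": "notes-v2"}), "notes", "hello"))
```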
+ if not final_name or not final_name.strip(): + return {"status": "error", "message": "File name cannot be empty."} + + final_name = _ensure_extension(final_name, final_file_type) + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.DROPBOX_CONNECTOR, + ) + ) + connector = result.scalars().first() else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + connector = connectors[0] - return { - "status": "success", - "file_id": file_id, - "name": final_name, - "web_url": web_url, - "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}", - } + if not connector: + return { + "status": "error", + "message": "Selected Dropbox connector is invalid.", + } + + client = DropboxClient(session=db_session, connector_id=connector.id) + + parent_path = final_parent_folder_path or "" + file_path = ( + f"{parent_path}/{final_name}" if parent_path else f"/{final_name}" + ) + + if final_file_type == "paper": + created = await client.create_paper_doc( + file_path, final_content or "" + ) + file_id = created.get("file_id", "") + web_url = created.get("url", "") + else: + docx_bytes = _markdown_to_docx(final_content or "") + created = await client.upload_file( + file_path, docx_bytes, mode="add", autorename=True + ) + file_id = created.get("id", "") + web_url = "" + + logger.info(f"Dropbox file created: id={file_id}, name={final_name}") + + kb_message_suffix = "" + try: + from app.services.dropbox import DropboxKBSyncService + + kb_service = DropboxKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + file_id=file_id, + file_name=final_name, + file_path=file_path, + web_url=web_url, + content=final_content, + connector_id=connector.id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." 
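Stepping back, the lifecycle spelled out in each factory docstring is identical across this PR: the injected `db_session` is accepted for registry compatibility, immediately discarded, and every tool invocation opens its own short-lived session. Distilled to a sketch, using a generic SQLAlchemy `async_sessionmaker` as a stand-in for `app.db.async_session_maker` (connection URL illustrative):

```python
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

# Stand-in for app.db.async_session_maker; the URL is illustrative.
engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/surfsense")
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)


def create_example_tool(db_session: AsyncSession | None = None):
    # Accepted only so the registry's call signature stays stable; holding
    # it would leak a request-scoped session into a cached closure.
    del db_session

    async def _run() -> dict[str, str]:
        # Fresh session per call: valid however long the compiled-agent
        # cache keeps this closure alive, and whichever request triggers it.
        async with async_session_maker() as session:
            await session.execute(text("SELECT 1"))
            return {"status": "success"}

    return _run
```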
+ + return { + "status": "success", + "file_id": file_id, + "name": final_name, + "web_url": web_url, + "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py index 12559b57a..0e59e49db 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py @@ -13,6 +13,7 @@ from app.db import ( DocumentType, SearchSourceConnector, SearchSourceConnectorType, + async_session_maker, ) logger = logging.getLogger(__name__) @@ -23,6 +24,23 @@ def create_delete_dropbox_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the delete_dropbox_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured delete_dropbox_file tool + """ + del db_session # per-call session — see docstring + @tool async def delete_dropbox_file( file_name: str, @@ -55,33 +73,14 @@ def create_delete_dropbox_file_tool( f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Dropbox tool not properly configured.", } try: - doc_result = await db_session.execute( - select(Document) - .join( - SearchSourceConnector, - Document.connector_id == SearchSourceConnector.id, - ) - .filter( - and_( - Document.search_space_id == search_space_id, - Document.document_type == DocumentType.DROPBOX_FILE, - func.lower(Document.title) == func.lower(file_name), - SearchSourceConnector.user_id == user_id, - ) - ) - .order_by(Document.updated_at.desc().nullslast()) - .limit(1) - ) - document = doc_result.scalars().first() - - if not document: + async with async_session_maker() as db_session: doc_result = await db_session.execute( select(Document) .join( @@ -92,13 +91,7 @@ def create_delete_dropbox_file_tool( and_( Document.search_space_id == search_space_id, Document.document_type == DocumentType.DROPBOX_FILE, - func.lower( - cast( - Document.document_metadata["dropbox_file_name"], - String, - ) - ) - == func.lower(file_name), + func.lower(Document.title) == func.lower(file_name), SearchSourceConnector.user_id == user_id, ) ) @@ -107,99 +100,63 @@ def create_delete_dropbox_file_tool( ) document = doc_result.scalars().first() - if not document: - return { - "status": "not_found", - "message": ( - f"File '{file_name}' not found in your indexed Dropbox files. " - "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " - "or (3) the file name is different." 
- ), - } - - if not document.connector_id: - return { - "status": "error", - "message": "Document has no associated connector.", - } - - meta = document.document_metadata or {} - file_path = meta.get("dropbox_path") - file_id = meta.get("dropbox_file_id") - document_id = document.id - - if not file_path: - return { - "status": "error", - "message": "File path is missing. Please re-index the file.", - } - - conn_result = await db_session.execute( - select(SearchSourceConnector).filter( - and_( - SearchSourceConnector.id == document.connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.DROPBOX_CONNECTOR, + if not document: + doc_result = await db_session.execute( + select(Document) + .join( + SearchSourceConnector, + Document.connector_id == SearchSourceConnector.id, + ) + .filter( + and_( + Document.search_space_id == search_space_id, + Document.document_type == DocumentType.DROPBOX_FILE, + func.lower( + cast( + Document.document_metadata["dropbox_file_name"], + String, + ) + ) + == func.lower(file_name), + SearchSourceConnector.user_id == user_id, + ) + ) + .order_by(Document.updated_at.desc().nullslast()) + .limit(1) ) - ) - ) - connector = conn_result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Dropbox connector not found or access denied.", - } + document = doc_result.scalars().first() - cfg = connector.config or {} - if cfg.get("auth_expired"): - return { - "status": "auth_error", - "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "dropbox", - } + if not document: + return { + "status": "not_found", + "message": ( + f"File '{file_name}' not found in your indexed Dropbox files. " + "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " + "or (3) the file name is different." + ), + } - context = { - "file": { - "file_id": file_id, - "file_path": file_path, - "name": file_name, - "document_id": document_id, - }, - "account": { - "id": connector.id, - "name": connector.name, - "user_email": cfg.get("user_email"), - }, - } + if not document.connector_id: + return { + "status": "error", + "message": "Document has no associated connector.", + } - result = request_approval( - action_type="dropbox_file_trash", - tool_name="delete_dropbox_file", - params={ - "file_path": file_path, - "connector_id": connector.id, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + meta = document.document_metadata or {} + file_path = meta.get("dropbox_path") + file_id = meta.get("dropbox_file_id") + document_id = document.id - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } + if not file_path: + return { + "status": "error", + "message": "File path is missing. 
Please re-index the file.", + } - final_file_path = result.params.get("file_path", file_path) - final_connector_id = result.params.get("connector_id", connector.id) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if final_connector_id != connector.id: - result = await db_session.execute( + conn_result = await db_session.execute( select(SearchSourceConnector).filter( and_( - SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.id == document.connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type @@ -207,61 +164,128 @@ def create_delete_dropbox_file_tool( ) ) ) - validated_connector = result.scalars().first() - if not validated_connector: + connector = conn_result.scalars().first() + if not connector: return { "status": "error", - "message": "Selected Dropbox connector is invalid or has been disconnected.", + "message": "Dropbox connector not found or access denied.", } - actual_connector_id = validated_connector.id - else: - actual_connector_id = connector.id - logger.info( - f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}" - ) + cfg = connector.config or {} + if cfg.get("auth_expired"): + return { + "status": "auth_error", + "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "dropbox", + } - client = DropboxClient(session=db_session, connector_id=actual_connector_id) - await client.delete_file(final_file_path) + context = { + "file": { + "file_id": file_id, + "file_path": file_path, + "name": file_name, + "document_id": document_id, + }, + "account": { + "id": connector.id, + "name": connector.name, + "user_email": cfg.get("user_email"), + }, + } - logger.info(f"Dropbox file deleted: path={final_file_path}") - - trash_result: dict[str, Any] = { - "status": "success", - "file_id": file_id, - "message": f"Successfully deleted '{file_name}' from Dropbox.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - doc = doc_result.scalars().first() - if doc: - await db_session.delete(doc) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"File deleted, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + result = request_approval( + action_type="dropbox_file_trash", + tool_name="delete_dropbox_file", + params={ + "file_path": file_path, + "connector_id": connector.id, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - return trash_result + if result.rejected: + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_file_path = result.params.get("file_path", file_path) + final_connector_id = result.params.get("connector_id", connector.id) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + if final_connector_id != connector.id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + and_( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id + == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.DROPBOX_CONNECTOR, + ) + ) + ) + validated_connector = result.scalars().first() + if not validated_connector: + return { + "status": "error", + "message": "Selected Dropbox connector is invalid or has been disconnected.", + } + actual_connector_id = validated_connector.id + else: + actual_connector_id = connector.id + + logger.info( + f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}" + ) + + client = DropboxClient( + session=db_session, connector_id=actual_connector_id + ) + await client.delete_file(final_file_path) + + logger.info(f"Dropbox file deleted: path={final_file_path}") + + trash_result: dict[str, Any] = { + "status": "success", + "file_id": file_id, + "message": f"Successfully deleted '{file_name}' from Dropbox.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + doc = doc_result.scalars().first() + if doc: + await db_session.delete(doc) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"File deleted, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py index 3803fa39c..9e287ac51 100644 --- a/surfsense_backend/app/agents/new_chat/tools/generate_image.py +++ b/surfsense_backend/app/agents/new_chat/tools/generate_image.py @@ -31,6 +31,7 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) @@ -49,12 +50,16 @@ _PROVIDER_MAP = { } +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + def _build_model_string( provider: str, model_name: str, custom_provider: str | None ) -> str: - if custom_provider: - return f"{custom_provider}/{model_name}" - prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) + prefix = _resolve_provider_prefix(provider, custom_provider) return f"{prefix}/{model_name}" @@ -146,14 +151,18 @@ def create_generate_image_tool( "error": f"Image generation 
config {config_id} not found" } - model_string = _build_model_string( - cfg.get("provider", ""), - cfg["model_name"], - cfg.get("custom_provider"), + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) + model_string = f"{provider_prefix}/{cfg['model_name']}" gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base if cfg.get("api_version"): gen_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -175,14 +184,18 @@ def create_generate_image_tool( "error": f"Image generation config {config_id} not found" } - model_string = _build_model_string( - db_cfg.provider.value, - db_cfg.model_name, - db_cfg.custom_provider, + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) + model_string = f"{provider_prefix}/{db_cfg.model_name}" gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base if db_cfg.api_version: gen_kwargs["api_version"] = db_cfg.api_version if db_cfg.litellm_params: diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py b/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py new file mode 100644 index 000000000..0ca1191a4 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py @@ -0,0 +1,41 @@ +from typing import Any + +from app.db import SearchSourceConnector +from app.services.composio_service import ComposioService + + +def split_recipients(value: str | None) -> list[str]: + if not value: + return [] + return [recipient.strip() for recipient in value.split(",") if recipient.strip()] + + +def unwrap_composio_data(data: Any) -> Any: + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner) + return inner + return data + + +async def execute_composio_gmail_tool( + connector: SearchSourceConnector, + user_id: str, + tool_name: str, + params: dict[str, Any], +) -> tuple[Any, str | None]: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return None, "Composio connected account ID not found for this Gmail connector." 
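The two small helpers above carry fixed contracts: `split_recipients` turns a comma-separated header value into a trimmed list, and `unwrap_composio_data` peels the two envelope shapes a Composio response can arrive in (an outer `data` wrapper, optionally containing `response_data`). Their behavior on each case, restated with the bodies copied verbatim from the helpers above:

```python
from typing import Any


def split_recipients(value: str | None) -> list[str]:
    # Verbatim copy of the helper above: comma-separated -> trimmed list.
    if not value:
        return []
    return [recipient.strip() for recipient in value.split(",") if recipient.strip()]


def unwrap_composio_data(data: Any) -> Any:
    # Verbatim copy of the helper above: peel `data`, then `response_data`.
    if isinstance(data, dict):
        inner = data.get("data", data)
        if isinstance(inner, dict):
            return inner.get("response_data", inner)
        return inner
    return data


assert split_recipients(None) == []
assert split_recipients("a@x.com, b@y.com,") == ["a@x.com", "b@y.com"]
assert unwrap_composio_data({"data": {"response_data": {"id": "1"}}}) == {"id": "1"}
assert unwrap_composio_data({"data": {"id": "2"}}) == {"id": "2"}
assert unwrap_composio_data({"id": "3"}) == {"id": "3"}
assert unwrap_composio_data(["already", "bare"]) == ["already", "bare"]
```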
+ + result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Gmail error") + + return unwrap_composio_data(result.get("data")), None diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py index 0bd044695..c88b48d2d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_create_gmail_draft_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_gmail_draft tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_gmail_draft tool + """ + del db_session # per-call session — see docstring + @tool async def create_gmail_draft( to: str, @@ -57,246 +75,276 @@ def create_create_gmail_draft_tool( """ logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Gmail tool not properly configured. Please contact support.", } try: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Gmail accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - logger.info( - f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'" - ) - result = request_approval( - action_type="gmail_draft_creation", - tool_name="create_gmail_draft", - params={ - "to": to, - "subject": subject, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The draft was not created. 
Do not ask again or suggest alternatives.", - } - - final_to = result.params.get("to", to) - final_subject = result.params.get("subject", subject) - final_body = result.params.get("body", body) - final_cc = result.params.get("cc", cc) - final_bcc = result.params.get("bcc", bcc) - final_connector_id = result.params.get("connector_id") - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) + async with async_session_maker() as db_session: + metadata_service = GmailToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), + + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" ) - ) - connector = result.scalars().first() - if not connector: + return {"status": "error", "message": context["error"]} + + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + logger.warning("All Gmail accounts have expired authentication") return { - "status": "error", - "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", + "status": "auth_error", + "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "gmail", } - actual_connector_id = connector.id - logger.info( - f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" - ) + logger.info( + f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'" + ) + result = request_approval( + action_type="gmail_draft_creation", + tool_name="create_gmail_draft", + params={ + "to": to, + "subject": subject, + "body": body, + "cc": cc, + "bcc": bcc, + "connector_id": None, + }, + context=context, + ) - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The draft was not created. 
Do not ask again or suggest alternatives.", + } - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) + final_to = result.params.get("to", to) + final_subject = result.params.get("subject", subject) + final_body = result.params.get("body", body) + final_cc = result.params.get("cc", cc) + final_bcc = result.params.get("bcc", bcc) + final_connector_id = result.params.get("connector_id") + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _gmail_types = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + ] + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Gmail connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id else: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = token_encryption.decrypt_token( - config_data["refresh_token"] - ) - if config_data.get("client_secret"): - config_data["client_secret"] = token_encryption.decrypt_token( - config_data["client_secret"] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - - message = MIMEText(final_body) - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: - created = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .create(userId="me", body={"message": {"raw": raw}}) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector 
{actual_connector_id}: {api_err}" ) - try: - from sqlalchemy.orm.attributes import flag_modified + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", + } + actual_connector_id = connector.id - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id + logger.info( + f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" + ) + + is_composio_gmail = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ) + if is_composio_gmail: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } + else: + from google.oauth2.credentials import Credentials + + from app.config import config + from app.utils.oauth_security import TokenEncryption + + config_data = dict(connector.config) + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = ( + token_encryption.decrypt_token( + config_data["refresh_token"] + ) + ) + if config_data.get("client_secret"): + config_data["client_secret"] = ( + token_encryption.decrypt_token( + config_data["client_secret"] + ) + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + message = MIMEText(final_body) + message["to"] = final_to + message["subject"] = final_subject + if final_cc: + message["cc"] = final_cc + if final_bcc: + message["bcc"] = final_bcc + raw = base64.urlsafe_b64encode(message.as_bytes()).decode() + + try: + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + split_recipients, ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: + + created, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_CREATE_EMAIL_DRAFT", + { + "user_id": "me", + "recipient_email": final_to, + "subject": final_subject, + "body": final_body, + "cc": split_recipients(final_cc), + "bcc": split_recipients(final_bcc), + "is_html": False, + }, + ) + if error: + raise RuntimeError(error) + if not isinstance(created, dict): + created = {} + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + created = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .drafts() + .create(userId="me", body={"message": {"raw": raw}}) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and 
api_err.resp.status == 403: logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + try: + from sqlalchemy.orm.attributes import flag_modified - logger.info(f"Gmail draft created: id={created.get('id')}") + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + raise - kb_message_suffix = "" - try: - from app.services.gmail import GmailKBSyncService + logger.info(f"Gmail draft created: id={created.get('id')}") - kb_service = GmailKBSyncService(db_session) - draft_message = created.get("message", {}) - kb_result = await kb_service.sync_after_create( - message_id=draft_message.get("id", ""), - thread_id=draft_message.get("threadId", ""), - subject=final_subject, - sender="me", - date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - body_text=final_body, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - draft_id=created.get("id"), - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: + kb_message_suffix = "" + try: + from app.services.gmail import GmailKBSyncService + + kb_service = GmailKBSyncService(db_session) + draft_message = created.get("message", {}) + kb_result = await kb_service.sync_after_create( + message_id=draft_message.get("id", ""), + thread_id=draft_message.get("threadId", ""), + subject=final_subject, + sender="me", + date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + body_text=final_body, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + draft_id=created.get("id"), + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." 
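One subtlety in the 403 handler above: `auth_expired` is persisted by assigning a brand-new config dict and then calling `flag_modified`. Both steps exist because SQLAlchemy's plain `JSON` column type only detects attribute reassignment, never in-place mutation of a loaded dict. The idiom in isolation, with an illustrative model standing in for `SearchSourceConnector`:

```python
from sqlalchemy import JSON, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.orm.attributes import flag_modified


class Base(DeclarativeBase):
    pass


class Connector(Base):
    """Illustrative stand-in for SearchSourceConnector."""

    __tablename__ = "connectors"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    config: Mapped[dict] = mapped_column(JSON, default=dict)


def mark_auth_expired(connector: Connector) -> None:
    # connector.config["auth_expired"] = True on its own would not be
    # flushed: plain JSON columns do not track in-place mutation.
    connector.config = {**connector.config, "auth_expired": True}
    # Reassignment already marks the attribute dirty; flag_modified is the
    # belt-and-braces step the handler takes so the UPDATE is emitted even
    # if change detection were ever bypassed.
    flag_modified(connector, "config")
```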
- return { - "status": "success", - "draft_id": created.get("id"), - "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}", - } + return { + "status": "success", + "draft_id": created.get("id"), + "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py index deec1627c..464713591 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py @@ -5,7 +5,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -20,6 +20,23 @@ def create_read_gmail_email_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_gmail_email tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_gmail_email tool + """ + del db_session # per-call session — see docstring + @tool async def read_gmail_email(message_id: str) -> dict[str, Any]: """Read the full content of a specific Gmail email by its message ID. @@ -32,60 +49,115 @@ def create_read_gmail_email_tool( Returns: Dictionary with status and the full email content formatted as markdown. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Gmail tool not properly configured."} try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Gmail connector found. 
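Every write tool in this patch repeats the 403 handler visible in the hunk above: on an insufficient-permissions error it re-assigns the connector's JSON config with auth_expired set, flags the column for SQLAlchemy, commits, and returns a structured status instead of raising. A minimal sketch of that pattern, assuming the SearchSourceConnector model and AsyncSession usage from the hunks (mark_auth_expired itself is a hypothetical helper name, not a function in this diff):

    from sqlalchemy.ext.asyncio import AsyncSession
    from sqlalchemy.future import select
    from sqlalchemy.orm.attributes import flag_modified

    from app.db import SearchSourceConnector


    async def mark_auth_expired(db_session: AsyncSession, connector_id: int) -> None:
        # Best effort: callers swallow failures here so the structured
        # "insufficient_permissions" response still reaches the user.
        res = await db_session.execute(
            select(SearchSourceConnector).where(
                SearchSourceConnector.id == connector_id
            )
        )
        conn = res.scalar_one_or_none()
        if conn and not conn.config.get("auth_expired"):
            # Re-assign the JSON column rather than mutating it in place;
            # flag_modified tells SQLAlchemy the value changed, so the
            # UPDATE is actually emitted on commit.
            conn.config = {**conn.config, "auth_expired": True}
            flag_modified(conn, "config")
            await db_session.commit()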
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
index deec1627c..464713591 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py
@@ -5,7 +5,7 @@
 from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
 
 logger = logging.getLogger(__name__)
 
@@ -20,6 +20,23 @@ def create_read_gmail_email_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the read_gmail_email tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured read_gmail_email tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def read_gmail_email(message_id: str) -> dict[str, Any]:
         """Read the full content of a specific Gmail email by its message ID.
@@ -32,60 +49,115 @@ def create_read_gmail_email_tool(
         Returns:
             Dictionary with status and the full email content formatted as markdown.
         """
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {"status": "error", "message": "Gmail tool not properly configured."}
 
         try:
-            result = await db_session.execute(
-                select(SearchSourceConnector).filter(
-                    SearchSourceConnector.search_space_id == search_space_id,
-                    SearchSourceConnector.user_id == user_id,
-                    SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
-                )
-            )
-            connector = result.scalars().first()
-            if not connector:
-                return {
-                    "status": "error",
-                    "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
-                }
-
-            from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
-
-            creds = _build_credentials(connector)
-
-            from app.connectors.google_gmail_connector import GoogleGmailConnector
-
-            gmail = GoogleGmailConnector(
-                credentials=creds,
-                session=db_session,
-                user_id=user_id,
-                connector_id=connector.id,
-            )
-
-            detail, error = await gmail.get_message_details(message_id)
-            if error:
-                if (
-                    "re-authenticate" in error.lower()
-                    or "authentication failed" in error.lower()
-                ):
-                    return {
-                        "status": "auth_error",
-                        "message": error,
-                        "connector_type": "gmail",
-                    }
-                return {"status": "error", "message": error}
-
-            if not detail:
-                return {
-                    "status": "not_found",
-                    "message": f"Email with ID '{message_id}' not found.",
-                }
-
-            content = gmail.format_message_to_markdown(detail)
-
-            return {"status": "success", "message_id": message_id, "content": content}
+            async with async_session_maker() as db_session:
+                result = await db_session.execute(
+                    select(SearchSourceConnector).filter(
+                        SearchSourceConnector.search_space_id == search_space_id,
+                        SearchSourceConnector.user_id == user_id,
+                        SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
+                    )
+                )
+                connector = result.scalars().first()
+                if not connector:
+                    return {
+                        "status": "error",
+                        "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
+                    }
+
+                if (
+                    connector.connector_type
+                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+                ):
+                    cca_id = connector.config.get("composio_connected_account_id")
+                    if not cca_id:
+                        return {
+                            "status": "error",
+                            "message": "Composio connected account ID not found.",
+                        }
+
+                    from app.agents.new_chat.tools.gmail.search_emails import (
+                        _format_gmail_summary,
+                    )
+                    from app.services.composio_service import ComposioService
+
+                    service = ComposioService()
+                    detail, error = await service.get_gmail_message_detail(
+                        connected_account_id=cca_id,
+                        entity_id=f"surfsense_{user_id}",
+                        message_id=message_id,
+                    )
+                    if error:
+                        return {"status": "error", "message": error}
+                    if not detail:
+                        return {
+                            "status": "not_found",
+                            "message": f"Email with ID '{message_id}' not found.",
+                        }
+
+                    summary = _format_gmail_summary(detail)
+                    content = (
+                        f"# {summary['subject']}\n\n"
+                        f"**From:** {summary['from']}\n"
+                        f"**To:** {summary['to']}\n"
+                        f"**Date:** {summary['date']}\n\n"
+                        f"## Message Content\n\n"
+                        f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
+                        f"## Message Details\n\n"
+                        f"- **Message ID:** {summary['message_id']}\n"
+                        f"- **Thread ID:** {summary['thread_id']}\n"
+                    )
+                    return {
+                        "status": "success",
+                        "message_id": summary["message_id"] or message_id,
+                        "content": content,
+                    }
+
+                from app.agents.new_chat.tools.gmail.search_emails import (
+                    _build_credentials,
+                )
+
+                creds = _build_credentials(connector)
+
+                from app.connectors.google_gmail_connector import GoogleGmailConnector
+
+                gmail = GoogleGmailConnector(
+                    credentials=creds,
+                    session=db_session,
+                    user_id=user_id,
+                    connector_id=connector.id,
+                )
+
+                detail, error = await gmail.get_message_details(message_id)
+                if error:
+                    if (
+                        "re-authenticate" in error.lower()
+                        or "authentication failed" in error.lower()
+                    ):
+                        return {
+                            "status": "auth_error",
+                            "message": error,
+                            "connector_type": "gmail",
+                        }
+                    return {"status": "error", "message": error}
+
+                if not detail:
+                    return {
+                        "status": "not_found",
+                        "message": f"Email with ID '{message_id}' not found.",
+                    }
+
+                content = gmail.format_message_to_markdown(detail)
+
+                return {
+                    "status": "success",
+                    "message_id": message_id,
+                    "content": content,
+                }
 
         except Exception as e:
             from langgraph.errors import GraphInterrupt
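The factory docstring above describes the session-lifetime fix these files share: the tool closure may be cached alongside the compiled agent across HTTP requests, so it must not capture the request-scoped session it was built with. A minimal sketch of the shape, assuming async_session_maker from app.db as in the diff (create_example_tool and its body are illustrative, not part of the change):

    from typing import Any

    from langchain_core.tools import tool
    from sqlalchemy.ext.asyncio import AsyncSession

    from app.db import async_session_maker


    def create_example_tool(
        db_session: AsyncSession | None = None,
        search_space_id: int | None = None,
        user_id: str | None = None,
    ):
        # The factory can outlive the request that built it, so the
        # request-scoped session is discarded rather than closed over.
        del db_session  # reserved for registry compatibility

        @tool
        async def example_tool(query: str) -> dict[str, Any]:
            """Illustrative tool body using a per-call session."""
            if search_space_id is None or user_id is None:
                return {"status": "error", "message": "Tool not properly configured."}
            # A fresh AsyncSession per invocation: valid on cache hits,
            # closed automatically when the block exits.
            async with async_session_maker() as session:
                ...  # queries against `session` go here
            return {"status": "success"}

        return example_tool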
diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
index 2e363609e..3ce154c53 100644
--- a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
+++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py
@@ -6,7 +6,7 @@
 from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
 
 logger = logging.getLogger(__name__)
 
@@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector):
     from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
 
     if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
-        from app.utils.google_credentials import build_composio_credentials
-
-        cca_id = connector.config.get("composio_connected_account_id")
-        if not cca_id:
-            raise ValueError("Composio connected account ID not found.")
-        return build_composio_credentials(cca_id)
+        raise ValueError("Composio connectors must use Composio tool execution.")
 
     from google.oauth2.credentials import Credentials
 
@@ -67,11 +62,85 @@ def _build_credentials(connector: SearchSourceConnector):
     )
 
 
+def _gmail_headers(message: dict[str, Any]) -> dict[str, str]:
+    headers = message.get("payload", {}).get("headers", [])
+    return {
+        header.get("name", "").lower(): header.get("value", "")
+        for header in headers
+        if isinstance(header, dict)
+    }
+
+
+def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]:
+    headers = _gmail_headers(message)
+    return {
+        "message_id": message.get("id") or message.get("messageId"),
+        "thread_id": message.get("threadId"),
+        "subject": message.get("subject") or headers.get("subject", "No Subject"),
+        "from": message.get("sender") or headers.get("from", "Unknown"),
+        "to": message.get("to") or headers.get("to", ""),
+        "date": message.get("messageTimestamp") or headers.get("date", ""),
+        "snippet": message.get("snippet") or message.get("messageText", "")[:300],
+        "labels": message.get("labelIds", []),
+    }
+
+
+async def _search_composio_gmail(
+    connector: SearchSourceConnector,
+    user_id: str,
+    query: str,
+    max_results: int,
+) -> dict[str, Any]:
+    cca_id = connector.config.get("composio_connected_account_id")
+    if not cca_id:
+        return {
+            "status": "error",
+            "message": "Composio connected account ID not found.",
+        }
+
+    from app.services.composio_service import ComposioService
+
+    service = ComposioService()
+    messages, _next_token, _estimate, error = await service.get_gmail_messages(
+        connected_account_id=cca_id,
+        entity_id=f"surfsense_{user_id}",
+        query=query,
+        max_results=max_results,
+    )
+    if error:
+        return {"status": "error", "message": error}
+
+    emails = [_format_gmail_summary(message) for message in messages]
+    return {
+        "status": "success",
+        "emails": emails,
+        "total": len(emails),
+        "message": "No emails found." if not emails else None,
+    }
+
+
 def create_search_gmail_tool(
     db_session: AsyncSession | None = None,
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the search_gmail tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured search_gmail tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def search_gmail(
         query: str,
@@ -90,83 +159,92 @@ def create_search_gmail_tool(
             Dictionary with status and a list of email summaries including
             message_id, subject, from, date, snippet.
         """
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {"status": "error", "message": "Gmail tool not properly configured."}
 
         max_results = min(max_results, 20)
 
         try:
-            result = await db_session.execute(
-                select(SearchSourceConnector).filter(
-                    SearchSourceConnector.search_space_id == search_space_id,
-                    SearchSourceConnector.user_id == user_id,
-                    SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
-                )
-            )
-            connector = result.scalars().first()
-            if not connector:
-                return {
-                    "status": "error",
-                    "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
-                }
-
-            creds = _build_credentials(connector)
-
-            from app.connectors.google_gmail_connector import GoogleGmailConnector
-
-            gmail = GoogleGmailConnector(
-                credentials=creds,
-                session=db_session,
-                user_id=user_id,
-                connector_id=connector.id,
-            )
-
-            messages_list, error = await gmail.get_messages_list(
-                max_results=max_results, query=query
-            )
-            if error:
-                if (
-                    "re-authenticate" in error.lower()
-                    or "authentication failed" in error.lower()
-                ):
-                    return {
-                        "status": "auth_error",
-                        "message": error,
-                        "connector_type": "gmail",
-                    }
-                return {"status": "error", "message": error}
-
-            if not messages_list:
-                return {
-                    "status": "success",
-                    "emails": [],
-                    "total": 0,
-                    "message": "No emails found.",
-                }
-
-            emails = []
-            for msg in messages_list:
-                detail, err = await gmail.get_message_details(msg["id"])
-                if err:
-                    continue
-                headers = {
-                    h["name"].lower(): h["value"]
-                    for h in detail.get("payload", {}).get("headers", [])
-                }
-                emails.append(
-                    {
-                        "message_id": detail.get("id"),
-                        "thread_id": detail.get("threadId"),
-                        "subject": headers.get("subject", "No Subject"),
-                        "from": headers.get("from", "Unknown"),
-                        "to": headers.get("to", ""),
-                        "date": headers.get("date", ""),
-                        "snippet": detail.get("snippet", ""),
-                        "labels": detail.get("labelIds", []),
-                    }
-                )
-
-            return {"status": "success", "emails": emails, "total": len(emails)}
+            async with async_session_maker() as db_session:
+                result = await db_session.execute(
+                    select(SearchSourceConnector).filter(
+                        SearchSourceConnector.search_space_id == search_space_id,
+                        SearchSourceConnector.user_id == user_id,
+                        SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
+                    )
+                )
+                connector = result.scalars().first()
+                if not connector:
+                    return {
+                        "status": "error",
+                        "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
+                    }
+
+                if (
+                    connector.connector_type
+                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+                ):
+                    return await _search_composio_gmail(
+                        connector, str(user_id), query, max_results
+                    )
+
+                creds = _build_credentials(connector)
+
+                from app.connectors.google_gmail_connector import GoogleGmailConnector
+
+                gmail = GoogleGmailConnector(
+                    credentials=creds,
+                    session=db_session,
+                    user_id=user_id,
+                    connector_id=connector.id,
+                )
+
+                messages_list, error = await gmail.get_messages_list(
+                    max_results=max_results, query=query
+                )
+                if error:
+                    if (
+                        "re-authenticate" in error.lower()
+                        or "authentication failed" in error.lower()
+                    ):
+                        return {
+                            "status": "auth_error",
+                            "message": error,
+                            "connector_type": "gmail",
+                        }
+                    return {"status": "error", "message": error}
+
+                if not messages_list:
+                    return {
+                        "status": "success",
+                        "emails": [],
+                        "total": 0,
+                        "message": "No emails found.",
+                    }
+
+                emails = []
+                for msg in messages_list:
+                    detail, err = await gmail.get_message_details(msg["id"])
+                    if err:
+                        continue
+                    headers = {
+                        h["name"].lower(): h["value"]
+                        for h in detail.get("payload", {}).get("headers", [])
+                    }
+                    emails.append(
+                        {
+                            "message_id": detail.get("id"),
+                            "thread_id": detail.get("threadId"),
+                            "subject": headers.get("subject", "No Subject"),
+                            "from": headers.get("from", "Unknown"),
+                            "to": headers.get("to", ""),
+                            "date": headers.get("date", ""),
+                            "snippet": detail.get("snippet", ""),
+                            "labels": detail.get("labelIds", []),
+                        }
+                    )
+
+                return {"status": "success", "emails": emails, "total": len(emails)}
 
         except Exception as e:
             from langgraph.errors import GraphInterrupt
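The _format_gmail_summary helper added above is what lets the Composio and native paths return one summary shape: Composio responses carry top-level fields (messageId, sender, messageText), while native Gmail API payloads carry RFC 822 headers that _gmail_headers lower-cases for the fallback lookups. A sketch of the two input shapes, with made-up values:

    from app.agents.new_chat.tools.gmail.search_emails import _format_gmail_summary

    # Native Gmail API payload: metadata lives in payload.headers.
    native = {
        "id": "18fa12",
        "threadId": "18fa00",
        "snippet": "Quarterly numbers attached.",
        "payload": {
            "headers": [
                {"name": "Subject", "value": "Q3 report"},
                {"name": "From", "value": "cfo@example.com"},
            ]
        },
    }

    # Composio response: the same metadata as top-level fields.
    composio = {
        "messageId": "18fa12",
        "threadId": "18fa00",
        "subject": "Q3 report",
        "sender": "cfo@example.com",
        "messageText": "Quarterly numbers attached.",
    }

    # Both normalize to the same keys the agent consumes:
    # message_id, thread_id, subject, from, to, date, snippet, labels.
    assert _format_gmail_summary(native)["subject"] == "Q3 report"
    assert _format_gmail_summary(composio)["subject"] == "Q3 report"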
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured send_gmail_email tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def send_gmail_email(
         to: str,
@@ -58,247 +76,277 @@ def create_send_gmail_email_tool(
         """
         logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
 
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {
                 "status": "error",
                 "message": "Gmail tool not properly configured. Please contact support.",
             }
 
         try:
-            metadata_service = GmailToolMetadataService(db_session)
-            context = await metadata_service.get_creation_context(
-                search_space_id, user_id
-            )
-
-            if "error" in context:
-                logger.error(f"Failed to fetch creation context: {context['error']}")
-                return {"status": "error", "message": context["error"]}
-
-            accounts = context.get("accounts", [])
-            if accounts and all(a.get("auth_expired") for a in accounts):
-                logger.warning("All Gmail accounts have expired authentication")
-                return {
-                    "status": "auth_error",
-                    "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
-                    "connector_type": "gmail",
-                }
-
-            logger.info(
-                f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
-            )
-            result = request_approval(
-                action_type="gmail_email_send",
-                tool_name="send_gmail_email",
-                params={
-                    "to": to,
-                    "subject": subject,
-                    "body": body,
-                    "cc": cc,
-                    "bcc": bcc,
-                    "connector_id": None,
-                },
-                context=context,
-            )
-
-            if result.rejected:
-                return {
-                    "status": "rejected",
-                    "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
-                }
-
-            final_to = result.params.get("to", to)
-            final_subject = result.params.get("subject", subject)
-            final_body = result.params.get("body", body)
-            final_cc = result.params.get("cc", cc)
-            final_bcc = result.params.get("bcc", bcc)
-            final_connector_id = result.params.get("connector_id")
-
-            from sqlalchemy.future import select
-
-            from app.db import SearchSourceConnector, SearchSourceConnectorType
-
-            _gmail_types = [
-                SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
-                SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
-            ]
-
-            if final_connector_id is not None:
-                result = await db_session.execute(
-                    select(SearchSourceConnector).filter(
-                        SearchSourceConnector.id == final_connector_id,
-                        SearchSourceConnector.search_space_id == search_space_id,
-                        SearchSourceConnector.user_id == user_id,
-                        SearchSourceConnector.connector_type.in_(_gmail_types),
-                    )
-                )
-                connector = result.scalars().first()
-                if not connector:
-                    return {
-                        "status": "error",
-                        "message": "Selected Gmail connector is invalid or has been disconnected.",
-                    }
-                actual_connector_id = connector.id
-            else:
-                result = await db_session.execute(
-                    select(SearchSourceConnector).filter(
-                        SearchSourceConnector.search_space_id == search_space_id,
-                        SearchSourceConnector.user_id == user_id,
-                        SearchSourceConnector.connector_type.in_(_gmail_types),
-                    )
-                )
-                connector = result.scalars().first()
-                if not connector:
-                    return {
-                        "status": "error",
-                        "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
-                    }
-                actual_connector_id = connector.id
-
-            logger.info(
-                f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
-            )
-
-            if (
-                connector.connector_type
-                == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
-            ):
-                from app.utils.google_credentials import build_composio_credentials
-
-                cca_id = connector.config.get("composio_connected_account_id")
-                if cca_id:
-                    creds = build_composio_credentials(cca_id)
-                else:
-                    return {
-                        "status": "error",
-                        "message": "Composio connected account ID not found for this Gmail connector.",
-                    }
-            else:
-                from google.oauth2.credentials import Credentials
-
-                from app.config import config
-                from app.utils.oauth_security import TokenEncryption
-
-                config_data = dict(connector.config)
-                token_encrypted = config_data.get("_token_encrypted", False)
-                if token_encrypted and config.SECRET_KEY:
-                    token_encryption = TokenEncryption(config.SECRET_KEY)
-                    if config_data.get("token"):
-                        config_data["token"] = token_encryption.decrypt_token(
-                            config_data["token"]
-                        )
-                    if config_data.get("refresh_token"):
-                        config_data["refresh_token"] = token_encryption.decrypt_token(
-                            config_data["refresh_token"]
-                        )
-                    if config_data.get("client_secret"):
-                        config_data["client_secret"] = token_encryption.decrypt_token(
-                            config_data["client_secret"]
-                        )
-
-                exp = config_data.get("expiry", "")
-                if exp:
-                    exp = exp.replace("Z", "")
-
-                creds = Credentials(
-                    token=config_data.get("token"),
-                    refresh_token=config_data.get("refresh_token"),
-                    token_uri=config_data.get("token_uri"),
-                    client_id=config_data.get("client_id"),
-                    client_secret=config_data.get("client_secret"),
-                    scopes=config_data.get("scopes", []),
-                    expiry=datetime.fromisoformat(exp) if exp else None,
-                )
-
-            from googleapiclient.discovery import build
-
-            gmail_service = build("gmail", "v1", credentials=creds)
-
-            message = MIMEText(final_body)
-            message["to"] = final_to
-            message["subject"] = final_subject
-            if final_cc:
-                message["cc"] = final_cc
-            if final_bcc:
-                message["bcc"] = final_bcc
-            raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
-
-            try:
-                sent = await asyncio.get_event_loop().run_in_executor(
-                    None,
-                    lambda: (
-                        gmail_service.users()
-                        .messages()
-                        .send(userId="me", body={"raw": raw})
-                        .execute()
-                    ),
-                )
-            except Exception as api_err:
-                from googleapiclient.errors import HttpError
-
-                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
-                    logger.warning(
-                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
-                    )
-                    try:
-                        from sqlalchemy.orm.attributes import flag_modified
-
-                        _res = await db_session.execute(
-                            select(SearchSourceConnector).where(
-                                SearchSourceConnector.id == actual_connector_id
-                            )
-                        )
-                        _conn = _res.scalar_one_or_none()
-                        if _conn and not _conn.config.get("auth_expired"):
-                            _conn.config = {**_conn.config, "auth_expired": True}
-                            flag_modified(_conn, "config")
-                            await db_session.commit()
-                    except Exception:
-                        logger.warning(
-                            "Failed to persist auth_expired for connector %s",
-                            actual_connector_id,
-                            exc_info=True,
-                        )
-                    return {
-                        "status": "insufficient_permissions",
-                        "connector_id": actual_connector_id,
-                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
-                    }
-                raise
-
-            logger.info(
-                f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
-            )
-
-            kb_message_suffix = ""
-            try:
-                from app.services.gmail import GmailKBSyncService
-
-                kb_service = GmailKBSyncService(db_session)
-                kb_result = await kb_service.sync_after_create(
-                    message_id=sent.get("id", ""),
-                    thread_id=sent.get("threadId", ""),
-                    subject=final_subject,
-                    sender="me",
-                    date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
-                    body_text=final_body,
-                    connector_id=actual_connector_id,
-                    search_space_id=search_space_id,
-                    user_id=user_id,
-                )
-                if kb_result["status"] == "success":
-                    kb_message_suffix = " Your knowledge base has also been updated."
-                else:
-                    kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
-            except Exception as kb_err:
-                logger.warning(f"KB sync after send failed: {kb_err}")
-                kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
-
-            return {
-                "status": "success",
-                "message_id": sent.get("id"),
-                "thread_id": sent.get("threadId"),
-                "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
-            }
+            async with async_session_maker() as db_session:
+                metadata_service = GmailToolMetadataService(db_session)
+                context = await metadata_service.get_creation_context(
+                    search_space_id, user_id
+                )
+
+                if "error" in context:
+                    logger.error(
+                        f"Failed to fetch creation context: {context['error']}"
+                    )
+                    return {"status": "error", "message": context["error"]}
+
+                accounts = context.get("accounts", [])
+                if accounts and all(a.get("auth_expired") for a in accounts):
+                    logger.warning("All Gmail accounts have expired authentication")
+                    return {
+                        "status": "auth_error",
+                        "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
+                        "connector_type": "gmail",
+                    }
+
+                logger.info(
+                    f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
+                )
+                result = request_approval(
+                    action_type="gmail_email_send",
+                    tool_name="send_gmail_email",
+                    params={
+                        "to": to,
+                        "subject": subject,
+                        "body": body,
+                        "cc": cc,
+                        "bcc": bcc,
+                        "connector_id": None,
+                    },
+                    context=context,
+                )
+
+                if result.rejected:
+                    return {
+                        "status": "rejected",
+                        "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
+                    }
+
+                final_to = result.params.get("to", to)
+                final_subject = result.params.get("subject", subject)
+                final_body = result.params.get("body", body)
+                final_cc = result.params.get("cc", cc)
+                final_bcc = result.params.get("bcc", bcc)
+                final_connector_id = result.params.get("connector_id")
+
+                from sqlalchemy.future import select
+
+                from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+                _gmail_types = [
+                    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
+                    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
+                ]
+
+                if final_connector_id is not None:
+                    result = await db_session.execute(
+                        select(SearchSourceConnector).filter(
+                            SearchSourceConnector.id == final_connector_id,
+                            SearchSourceConnector.search_space_id == search_space_id,
+                            SearchSourceConnector.user_id == user_id,
+                            SearchSourceConnector.connector_type.in_(_gmail_types),
+                        )
+                    )
+                    connector = result.scalars().first()
+                    if not connector:
+                        return {
+                            "status": "error",
+                            "message": "Selected Gmail connector is invalid or has been disconnected.",
+                        }
+                    actual_connector_id = connector.id
+                else:
+                    result = await db_session.execute(
+                        select(SearchSourceConnector).filter(
+                            SearchSourceConnector.search_space_id == search_space_id,
+                            SearchSourceConnector.user_id == user_id,
+                            SearchSourceConnector.connector_type.in_(_gmail_types),
+                        )
+                    )
+                    connector = result.scalars().first()
+                    if not connector:
+                        return {
+                            "status": "error",
+                            "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
+                        }
+                    actual_connector_id = connector.id
+
+                logger.info(
+                    f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
+                )
+
+                is_composio_gmail = (
+                    connector.connector_type
+                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+                )
+                if is_composio_gmail:
+                    cca_id = connector.config.get("composio_connected_account_id")
+                    if not cca_id:
+                        return {
+                            "status": "error",
+                            "message": "Composio connected account ID not found for this Gmail connector.",
+                        }
+                else:
+                    from google.oauth2.credentials import Credentials
+
+                    from app.config import config
+                    from app.utils.oauth_security import TokenEncryption
+
+                    config_data = dict(connector.config)
+                    token_encrypted = config_data.get("_token_encrypted", False)
+                    if token_encrypted and config.SECRET_KEY:
+                        token_encryption = TokenEncryption(config.SECRET_KEY)
+                        if config_data.get("token"):
+                            config_data["token"] = token_encryption.decrypt_token(
+                                config_data["token"]
+                            )
+                        if config_data.get("refresh_token"):
+                            config_data["refresh_token"] = (
+                                token_encryption.decrypt_token(
+                                    config_data["refresh_token"]
+                                )
+                            )
+                        if config_data.get("client_secret"):
+                            config_data["client_secret"] = (
+                                token_encryption.decrypt_token(
+                                    config_data["client_secret"]
+                                )
+                            )
+
+                    exp = config_data.get("expiry", "")
+                    if exp:
+                        exp = exp.replace("Z", "")
+
+                    creds = Credentials(
+                        token=config_data.get("token"),
+                        refresh_token=config_data.get("refresh_token"),
+                        token_uri=config_data.get("token_uri"),
+                        client_id=config_data.get("client_id"),
+                        client_secret=config_data.get("client_secret"),
+                        scopes=config_data.get("scopes", []),
+                        expiry=datetime.fromisoformat(exp) if exp else None,
+                    )
+
+                message = MIMEText(final_body)
+                message["to"] = final_to
+                message["subject"] = final_subject
+                if final_cc:
+                    message["cc"] = final_cc
+                if final_bcc:
+                    message["bcc"] = final_bcc
+                raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
+
+                try:
+                    if is_composio_gmail:
+                        from app.agents.new_chat.tools.gmail.composio_helpers import (
+                            execute_composio_gmail_tool,
+                            split_recipients,
+                        )
+
+                        sent, error = await execute_composio_gmail_tool(
+                            connector,
+                            user_id,
+                            "GMAIL_SEND_EMAIL",
+                            {
+                                "user_id": "me",
+                                "recipient_email": final_to,
+                                "subject": final_subject,
+                                "body": final_body,
+                                "cc": split_recipients(final_cc),
+                                "bcc": split_recipients(final_bcc),
+                                "is_html": False,
+                            },
+                        )
+                        if error:
+                            raise RuntimeError(error)
+                        if not isinstance(sent, dict):
+                            sent = {}
+                    else:
+                        from googleapiclient.discovery import build
+
+                        gmail_service = build("gmail", "v1", credentials=creds)
+                        sent = await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            lambda: (
+                                gmail_service.users()
+                                .messages()
+                                .send(userId="me", body={"raw": raw})
+                                .execute()
+                            ),
+                        )
+                except Exception as api_err:
+                    from googleapiclient.errors import HttpError
+
+                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+                        logger.warning(
+                            f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
+                        )
+                        try:
+                            from sqlalchemy.orm.attributes import flag_modified
+
+                            _res = await db_session.execute(
+                                select(SearchSourceConnector).where(
+                                    SearchSourceConnector.id == actual_connector_id
+                                )
+                            )
+                            _conn = _res.scalar_one_or_none()
+                            if _conn and not _conn.config.get("auth_expired"):
+                                _conn.config = {**_conn.config, "auth_expired": True}
+                                flag_modified(_conn, "config")
+                                await db_session.commit()
+                        except Exception:
+                            logger.warning(
+                                "Failed to persist auth_expired for connector %s",
+                                actual_connector_id,
+                                exc_info=True,
+                            )
+                        return {
+                            "status": "insufficient_permissions",
+                            "connector_id": actual_connector_id,
+                            "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+                        }
+                    raise
+
+                logger.info(
+                    f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
+                )
+
+                kb_message_suffix = ""
+                try:
+                    from app.services.gmail import GmailKBSyncService
+
+                    kb_service = GmailKBSyncService(db_session)
+                    kb_result = await kb_service.sync_after_create(
+                        message_id=sent.get("id", ""),
+                        thread_id=sent.get("threadId", ""),
+                        subject=final_subject,
+                        sender="me",
+                        date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                        body_text=final_body,
+                        connector_id=actual_connector_id,
+                        search_space_id=search_space_id,
+                        user_id=user_id,
+                    )
+                    if kb_result["status"] == "success":
+                        kb_message_suffix = (
+                            " Your knowledge base has also been updated."
+                        )
+                    else:
+                        kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
+                except Exception as kb_err:
+                    logger.warning(f"KB sync after send failed: {kb_err}")
+                    kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
+
+                return {
+                    "status": "success",
+                    "message_id": sent.get("id"),
+                    "thread_id": sent.get("threadId"),
+                    "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
+                }
 
         except Exception as e:
             from langgraph.errors import GraphInterrupt
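Both send paths in the hunk above hinge on the same encoding step: Gmail's users.messages.send (and drafts.create) take the entire RFC 822 message, URL-safe base64-encoded, in the request body's raw field. Isolated as a small self-contained sketch (build_raw_gmail_message is a hypothetical helper name, not a function in this diff):

    import base64
    from email.mime.text import MIMEText


    def build_raw_gmail_message(
        to: str, subject: str, body: str, cc: str = "", bcc: str = ""
    ) -> str:
        """Encode a plain-text message the way the Gmail API expects it."""
        message = MIMEText(body)
        message["to"] = to
        message["subject"] = subject
        if cc:
            message["cc"] = cc
        if bcc:
            message["bcc"] = bcc
        # The Gmail API takes the full serialized message as URL-safe
        # base64 in the `raw` field of the request body.
        return base64.urlsafe_b64encode(message.as_bytes()).decode()


    # e.g. {"raw": build_raw_gmail_message("a@b.example", "Hi", "Hello")}

The native path then runs the blocking .execute() call in a thread-pool executor, as the diff shows, so the event loop stays free while the HTTP request is in flight.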
Please contact support.", } try: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_trash_context( - search_space_id, user_id, email_subject_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Email not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch trash context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Gmail account %s has expired authentication", - account.get("id"), + async with async_session_maker() as db_session: + metadata_service = GmailToolMetadataService(db_session) + context = await metadata_service.get_trash_context( + search_space_id, user_id, email_subject_or_id ) - return { - "status": "auth_error", - "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - email = context["email"] - message_id = email["message_id"] - document_id = email.get("document_id") - connector_id_from_context = context["account"]["id"] + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"Email not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch trash context: {error_msg}") + return {"status": "error", "message": error_msg} - if not message_id: - return { - "status": "error", - "message": "Message ID is missing from the indexed document. Please re-index the email and try again.", - } + account = context.get("account", {}) + if account.get("auth_expired"): + logger.warning( + "Gmail account %s has expired authentication", + account.get("id"), + ) + return { + "status": "auth_error", + "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "gmail", + } - logger.info( - f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="gmail_email_trash", - tool_name="trash_gmail_email", - params={ - "message_id": message_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + email = context["email"] + message_id = email["message_id"] + document_id = email.get("document_id") + connector_id_from_context = context["account"]["id"] - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The email was not trashed. 
Do not ask again or suggest alternatives.", - } - - final_message_id = result.params.get("message_id", message_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this email.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - - logger.info( - f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}" - ) - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not message_id: return { "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", + "message": "Message ID is missing from the indexed document. Please re-index the email and try again.", } - else: - from google.oauth2.credentials import Credentials - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = token_encryption.decrypt_token( - config_data["refresh_token"] - ) - if config_data.get("client_secret"): - config_data["client_secret"] = token_encryption.decrypt_token( - config_data["client_secret"] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + logger.info( + f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})" + ) + result = request_approval( + action_type="gmail_email_trash", + tool_name="trash_gmail_email", + params={ + "message_id": message_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - - try: - await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .messages() - 
.trash(userId="me", id=final_message_id) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {connector.id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.", } - raise - logger.info(f"Gmail email trashed: message_id={final_message_id}") - - trash_result: dict[str, Any] = { - "status": "success", - "message_id": final_message_id, - "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"Email trashed, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + final_message_id = result.params.get("message_id", message_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb ) - return trash_result + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this email.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _gmail_types = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + ] + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Gmail connector is invalid or has been disconnected.", + } + + logger.info( + f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}" + ) + + is_composio_gmail = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ) + if is_composio_gmail: + cca_id = 
connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } + else: + from google.oauth2.credentials import Credentials + + from app.config import config + from app.utils.oauth_security import TokenEncryption + + config_data = dict(connector.config) + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = ( + token_encryption.decrypt_token( + config_data["refresh_token"] + ) + ) + if config_data.get("client_secret"): + config_data["client_secret"] = ( + token_encryption.decrypt_token( + config_data["client_secret"] + ) + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + try: + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + ) + + _trashed, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_MOVE_TO_TRASH", + {"user_id": "me", "message_id": final_message_id}, + ) + if error: + raise RuntimeError(error) + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .messages() + .trash(userId="me", id=final_message_id) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {connector.id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + if not connector.config.get("auth_expired"): + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + connector.id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Gmail account needs additional permissions. 
Please re-authenticate in connector settings.", + } + raise + + logger.info(f"Gmail email trashed: message_id={final_message_id}") + + trash_result: dict[str, Any] = { + "status": "success", + "message_id": final_message_id, + "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"Email trashed, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py index 91178cd21..129b7defb 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_update_gmail_draft_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the update_gmail_draft tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured update_gmail_draft tool + """ + del db_session # per-call session — see docstring + @tool async def update_gmail_draft( draft_subject_or_id: str, @@ -76,294 +94,329 @@ def create_update_gmail_draft_tool( f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Gmail tool not properly configured. 
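Every destructive tool in this patch funnels through the same human-in-the-loop gate visible above: request_approval proposes the action's params, the user may reject or edit them, and the tool then reads the final values back out of result.params rather than trusting its own inputs. A condensed sketch of that control flow, with placeholder values (the params and context contents here are illustrative; only request_approval's call shape comes from the diff):

    from app.agents.new_chat.tools.hitl import request_approval

    # Stand-in for the metadata GmailToolMetadataService builds for the UI.
    context = {"account": {"id": 7}, "email": {"subject": "Q3 report"}}

    result = request_approval(
        action_type="gmail_email_trash",
        tool_name="trash_gmail_email",
        params={"message_id": "18fa12", "connector_id": 7, "delete_from_kb": True},
        context=context,
    )
    if result.rejected:
        # Hard stop: the tool returns a "rejected" status and the agent
        # is instructed not to retry or suggest alternatives.
        ...
    # The user may have edited fields in the approval dialog, so the
    # approved values win over the tool's original arguments.
    final_message_id = result.params.get("message_id", "18fa12")
    final_delete_from_kb = result.params.get("delete_from_kb", True)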
Please contact support.", } try: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, draft_subject_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Draft not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch update context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Gmail account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - email = context["email"] - message_id = email["message_id"] - document_id = email.get("document_id") - connector_id_from_context = account["id"] - draft_id_from_context = context.get("draft_id") - - original_subject = email.get("subject", draft_subject_or_id) - final_subject_default = subject if subject else original_subject - final_to_default = to if to else "" - - logger.info( - f"Requesting approval for updating Gmail draft: '{original_subject}' " - f"(message_id={message_id}, draft_id={draft_id_from_context})" - ) - result = request_approval( - action_type="gmail_draft_update", - tool_name="update_gmail_draft", - params={ - "message_id": message_id, - "draft_id": draft_id_from_context, - "to": final_to_default, - "subject": final_subject_default, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The draft was not updated. 
Do not ask again or suggest alternatives.", - } - - final_to = result.params.get("to", final_to_default) - final_subject = result.params.get("subject", final_subject_default) - final_body = result.params.get("body", body) - final_cc = result.params.get("cc", cc) - final_bcc = result.params.get("bcc", bcc) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_draft_id = result.params.get("draft_id", draft_id_from_context) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this draft.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - - logger.info( - f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}" - ) - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = token_encryption.decrypt_token( - config_data["refresh_token"] - ) - if config_data.get("client_secret"): - config_data["client_secret"] = token_encryption.decrypt_token( - config_data["client_secret"] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + async with async_session_maker() as db_session: + metadata_service = GmailToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, draft_subject_or_id ) - from googleapiclient.discovery import build + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"Draft not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch update context: {error_msg}") + return {"status": "error", "message": 
error_msg} - gmail_service = build("gmail", "v1", credentials=creds) - - # Resolve draft_id if not already available - if not final_draft_id: - logger.info( - f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}" - ) - final_draft_id = await _find_draft_id_by_message( - gmail_service, message_id - ) - - if not final_draft_id: - return { - "status": "error", - "message": ( - "Could not find this draft in Gmail. " - "It may have already been sent or deleted." - ), - } - - message = MIMEText(final_body) - if final_to: - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .update( - userId="me", - id=final_draft_id, - body={"message": {"raw": raw}}, - ) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: + account = context.get("account", {}) + if account.get("auth_expired"): logger.warning( - f"Insufficient permissions for connector {connector.id}: {api_err}" + "Gmail account %s has expired authentication", + account.get("id"), ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + "status": "auth_error", + "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "gmail", } - if isinstance(api_err, HttpError) and api_err.resp.status == 404: + + email = context["email"] + message_id = email["message_id"] + document_id = email.get("document_id") + connector_id_from_context = account["id"] + draft_id_from_context = context.get("draft_id") + + original_subject = email.get("subject", draft_subject_or_id) + final_subject_default = subject if subject else original_subject + final_to_default = to if to else "" + + logger.info( + f"Requesting approval for updating Gmail draft: '{original_subject}' " + f"(message_id={message_id}, draft_id={draft_id_from_context})" + ) + result = request_approval( + action_type="gmail_draft_update", + tool_name="update_gmail_draft", + params={ + "message_id": message_id, + "draft_id": draft_id_from_context, + "to": final_to_default, + "subject": final_subject_default, + "body": body, + "cc": cc, + "bcc": bcc, + "connector_id": connector_id_from_context, + }, + context=context, + ) + + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The draft was not updated. 
+
+                if result.rejected:
+                    return {
+                        "status": "rejected",
+                        "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
+                    }
+
+                final_to = result.params.get("to", final_to_default)
+                final_subject = result.params.get("subject", final_subject_default)
+                final_body = result.params.get("body", body)
+                final_cc = result.params.get("cc", cc)
+                final_bcc = result.params.get("bcc", bcc)
+                final_connector_id = result.params.get(
+                    "connector_id", connector_id_from_context
+                )
+                final_draft_id = result.params.get("draft_id", draft_id_from_context)
+
+                if not final_connector_id:
                     return {
                         "status": "error",
-                        "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
+                        "message": "No connector found for this draft.",
                     }
-                raise
-
-            logger.info(f"Gmail draft updated: id={updated.get('id')}")
-
-            kb_message_suffix = ""
-            if document_id:
-                try:
-                    from sqlalchemy.future import select as sa_select
-                    from sqlalchemy.orm.attributes import flag_modified
-
-                    from app.db import Document
-
-                    doc_result = await db_session.execute(
-                        sa_select(Document).filter(Document.id == document_id)
-                    )
-                    document = doc_result.scalars().first()
-                    if document:
-                        document.source_markdown = final_body
-                        document.title = final_subject
-                        meta = dict(document.document_metadata or {})
-                        meta["subject"] = final_subject
-                        meta["draft_id"] = updated.get("id", final_draft_id)
-                        updated_msg = updated.get("message", {})
-                        if updated_msg.get("id"):
-                            meta["message_id"] = updated_msg["id"]
-                        document.document_metadata = meta
-                        flag_modified(document, "document_metadata")
-                        await db_session.commit()
-                        kb_message_suffix = (
-                            " Your knowledge base has also been updated."
-                        )
-                        logger.info(
-                            f"KB document {document_id} updated for draft {final_draft_id}"
-                        )
+
+                from sqlalchemy.future import select
+
+                from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+                _gmail_types = [
+                    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
+                    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
+                ]
+
+                result = await db_session.execute(
+                    select(SearchSourceConnector).filter(
+                        SearchSourceConnector.id == final_connector_id,
+                        SearchSourceConnector.search_space_id == search_space_id,
+                        SearchSourceConnector.user_id == user_id,
+                        SearchSourceConnector.connector_type.in_(_gmail_types),
+                    )
+                )
+                connector = result.scalars().first()
+                if not connector:
+                    return {
+                        "status": "error",
+                        "message": "Selected Gmail connector is invalid or has been disconnected.",
+                    }
+
+                logger.info(
+                    f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
+                )
+
+                is_composio_gmail = (
+                    connector.connector_type
+                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+                )
+                if is_composio_gmail:
+                    cca_id = connector.config.get("composio_connected_account_id")
+                    if not cca_id:
+                        return {
+                            "status": "error",
+                            "message": "Composio connected account ID not found for this Gmail connector.",
+                        }
+                else:
+                    from google.oauth2.credentials import Credentials
+
+                    from app.config import config
+                    from app.utils.oauth_security import TokenEncryption
+
+                    config_data = dict(connector.config)
+                    token_encrypted = config_data.get("_token_encrypted", False)
+                    if token_encrypted and config.SECRET_KEY:
+                        token_encryption = TokenEncryption(config.SECRET_KEY)
+                        if config_data.get("token"):
+                            config_data["token"] = token_encryption.decrypt_token(
+                                config_data["token"]
+                            )
+                        if config_data.get("refresh_token"):
+                            config_data["refresh_token"] = (
+                                token_encryption.decrypt_token(
+                                    config_data["refresh_token"]
+                                )
+                            )
+                        if config_data.get("client_secret"):
+                            config_data["client_secret"] = (
+                                token_encryption.decrypt_token(
+                                    config_data["client_secret"]
+                                )
+                            )
+
+                    exp = config_data.get("expiry", "")
+                    if exp:
+                        exp = exp.replace("Z", "")
+
+                    creds = Credentials(
+                        token=config_data.get("token"),
+                        refresh_token=config_data.get("refresh_token"),
+                        token_uri=config_data.get("token_uri"),
+                        client_id=config_data.get("client_id"),
+                        client_secret=config_data.get("client_secret"),
+                        scopes=config_data.get("scopes", []),
+                        expiry=datetime.fromisoformat(exp) if exp else None,
+                    )
-                    else:
-                        kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
-                except Exception as kb_err:
-                    logger.warning(f"KB update after draft edit failed: {kb_err}")
-                    await db_session.rollback()
-                    kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
-
-            return {
-                "status": "success",
-                "draft_id": updated.get("id"),
-                "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
-            }
+
+                # Resolve draft_id if not already available
+                if not final_draft_id:
+                    logger.info(
+                        f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
+                    )
+                    if is_composio_gmail:
+                        final_draft_id = await _find_composio_draft_id_by_message(
+                            connector, user_id, message_id
+                        )
+                    else:
+                        from googleapiclient.discovery import build
+
+                        gmail_service = build("gmail", "v1", credentials=creds)
+                        final_draft_id = await _find_draft_id_by_message(
+                            gmail_service, message_id
+                        )
+
+                    if not final_draft_id:
+                        return {
+                            "status": "error",
+                            "message": (
+                                "Could not find this draft in Gmail. "
+                                "It may have already been sent or deleted."
+                            ),
+                        }
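+
+                # Gmail's drafts.update replaces the draft's entire message, so
+                # the full MIME payload is rebuilt here rather than patching
+                # individual header fields.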
+                message = MIMEText(final_body)
+                if final_to:
+                    message["to"] = final_to
+                message["subject"] = final_subject
+                if final_cc:
+                    message["cc"] = final_cc
+                if final_bcc:
+                    message["bcc"] = final_bcc
+                raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
+
+                try:
+                    if is_composio_gmail:
+                        from app.agents.new_chat.tools.gmail.composio_helpers import (
+                            execute_composio_gmail_tool,
+                            split_recipients,
+                        )
+
+                        updated, error = await execute_composio_gmail_tool(
+                            connector,
+                            user_id,
+                            "GMAIL_UPDATE_DRAFT",
+                            {
+                                "user_id": "me",
+                                "draft_id": final_draft_id,
+                                "recipient_email": final_to,
+                                "subject": final_subject,
+                                "body": final_body,
+                                "cc": split_recipients(final_cc),
+                                "bcc": split_recipients(final_bcc),
+                                "is_html": False,
+                            },
+                        )
+                        if error:
+                            raise RuntimeError(error)
+                        if not isinstance(updated, dict):
+                            updated = {}
+                    else:
+                        from googleapiclient.discovery import build
+
+                        gmail_service = build("gmail", "v1", credentials=creds)
+                        updated = await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            lambda: (
+                                gmail_service.users()
+                                .drafts()
+                                .update(
+                                    userId="me",
+                                    id=final_draft_id,
+                                    body={"message": {"raw": raw}},
+                                )
+                                .execute()
+                            ),
+                        )
+                except Exception as api_err:
+                    from googleapiclient.errors import HttpError
+
+                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+                        logger.warning(
+                            f"Insufficient permissions for connector {connector.id}: {api_err}"
+                        )
+                        try:
+                            from sqlalchemy.orm.attributes import flag_modified
+
+                            if not connector.config.get("auth_expired"):
+                                connector.config = {
+                                    **connector.config,
+                                    "auth_expired": True,
+                                }
+                                flag_modified(connector, "config")
+                                await db_session.commit()
+                        except Exception:
+                            logger.warning(
+                                "Failed to persist auth_expired for connector %s",
+                                connector.id,
+                                exc_info=True,
+                            )
+                        return {
+                            "status": "insufficient_permissions",
+                            "connector_id": connector.id,
+                            "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
+                        }
+                    if isinstance(api_err, HttpError) and api_err.resp.status == 404:
+                        return {
+                            "status": "error",
+                            "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
+                        }
+                    raise
+
+                logger.info(f"Gmail draft updated: id={updated.get('id')}")
+
+                kb_message_suffix = ""
+                if document_id:
+                    try:
+                        from sqlalchemy.future import select as sa_select
+                        from sqlalchemy.orm.attributes import flag_modified
+
+                        from app.db import Document
+
+                        doc_result = await db_session.execute(
+                            sa_select(Document).filter(Document.id == document_id)
+                        )
+                        document = doc_result.scalars().first()
+                        if document:
+                            document.source_markdown = final_body
+                            document.title = final_subject
+                            meta = dict(document.document_metadata or {})
+                            meta["subject"] = final_subject
+                            meta["draft_id"] = updated.get("id", final_draft_id)
+                            updated_msg = updated.get("message", {})
+                            if updated_msg.get("id"):
+                                meta["message_id"] = updated_msg["id"]
+                            document.document_metadata = meta
+                            flag_modified(document, "document_metadata")
+                            await db_session.commit()
+                            kb_message_suffix = (
+                                " Your knowledge base has also been updated."
+                            )
+                            logger.info(
+                                f"KB document {document_id} updated for draft {final_draft_id}"
+                            )
+                        else:
+                            kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
+                    except Exception as kb_err:
+                        logger.warning(f"KB update after draft edit failed: {kb_err}")
+                        await db_session.rollback()
+                        kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
+
+                return {
+                    "status": "success",
+                    "draft_id": updated.get("id"),
+                    "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
+                }
         except Exception as e:
             from langgraph.errors import GraphInterrupt
@@ -408,3 +461,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str
     except Exception as e:
         logger.warning(f"Failed to look up draft by message_id: {e}")
         return None
+
+
+async def _find_composio_draft_id_by_message(
+    connector: Any, user_id: str, message_id: str
+) -> str | None:
+    from app.agents.new_chat.tools.gmail.composio_helpers import (
+        execute_composio_gmail_tool,
+    )
+
+    page_token = ""
+    while True:
+        params: dict[str, Any] = {
+            "user_id": "me",
+            "max_results": 100,
+            "verbose": False,
+        }
+        if page_token:
+            params["page_token"] = page_token
+
+        data, error = await execute_composio_gmail_tool(
+            connector, user_id, "GMAIL_LIST_DRAFTS", params
+        )
+        if error or not isinstance(data, dict):
+            return None
+
+        for draft in data.get("drafts", []):
+            if draft.get("message", {}).get("id") == message_id:
+                return draft.get("id")
+
+        page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
+        if not page_token:
+            return None
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
index 37bcf083e..dec92cc8b 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
 from app.services.google_calendar import GoogleCalendarToolMetadataService
 
 logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_create_calendar_event_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the create_calendar_event tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured create_calendar_event tool
+    """
+    del db_session  # per-call session — see docstring
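+    # Why per-call sessions matter here — a hypothetical cache-hit flow
+    # (names illustrative only, not actual code in this module):
+    #   tools = build_tools(db_session=request_session)   # closure cached
+    #   ...request ends; request_session is closed...
+    #   cached tool invoked by a later request -> closed-session error.
+    # Opening a fresh AsyncSession inside the tool body avoids this.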
+
     @tool
     async def create_calendar_event(
         summary: str,
@@ -60,254 +78,294 @@ def create_create_calendar_event_tool(
             f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
         )
 
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {
                 "status": "error",
                 "message": "Google Calendar tool not properly configured. Please contact support.",
             }
 
         try:
-            metadata_service = GoogleCalendarToolMetadataService(db_session)
-            context = await metadata_service.get_creation_context(
-                search_space_id, user_id
-            )
-
-            if "error" in context:
-                logger.error(f"Failed to fetch creation context: {context['error']}")
-                return {"status": "error", "message": context["error"]}
-
-            accounts = context.get("accounts", [])
-            if accounts and all(a.get("auth_expired") for a in accounts):
-                logger.warning(
-                    "All Google Calendar accounts have expired authentication"
-                )
-                return {
-                    "status": "auth_error",
-                    "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
-                    "connector_type": "google_calendar",
-                }
-
-            logger.info(
-                f"Requesting approval for creating calendar event: summary='{summary}'"
-            )
-            result = request_approval(
-                action_type="google_calendar_event_creation",
-                tool_name="create_calendar_event",
-                params={
-                    "summary": summary,
-                    "start_datetime": start_datetime,
-                    "end_datetime": end_datetime,
-                    "description": description,
-                    "location": location,
-                    "attendees": attendees,
-                    "timezone": context.get("timezone"),
-                    "connector_id": None,
-                },
-                context=context,
-            )
-
-            if result.rejected:
-                return {
-                    "status": "rejected",
-                    "message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
-                }
-
-            final_summary = result.params.get("summary", summary)
-            final_start_datetime = result.params.get("start_datetime", start_datetime)
-            final_end_datetime = result.params.get("end_datetime", end_datetime)
-            final_description = result.params.get("description", description)
-            final_location = result.params.get("location", location)
-            final_attendees = result.params.get("attendees", attendees)
-            final_connector_id = result.params.get("connector_id")
-
-            if not final_summary or not final_summary.strip():
-                return {"status": "error", "message": "Event summary cannot be empty."}
-
-            from sqlalchemy.future import select
-
-            from app.db import SearchSourceConnector, SearchSourceConnectorType
-
-            _calendar_types = [
-                SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
-                SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
-            ]
-
-            if final_connector_id is not None:
-                result = await db_session.execute(
-                    select(SearchSourceConnector).filter(
-                        SearchSourceConnector.id == final_connector_id,
-                        SearchSourceConnector.search_space_id == search_space_id,
-                        SearchSourceConnector.user_id == user_id,
-                        SearchSourceConnector.connector_type.in_(_calendar_types),
-                    )
-                )
-                connector = result.scalars().first()
-                if not connector:
-                    return {
-                        "status": "error",
-                        "message": "Selected Google Calendar connector is invalid or has been disconnected.",
-                    }
-                actual_connector_id = connector.id
-            else:
-                result = await db_session.execute(
-                    select(SearchSourceConnector).filter(
-                        SearchSourceConnector.search_space_id == search_space_id,
-                        SearchSourceConnector.user_id == user_id,
-                        SearchSourceConnector.connector_type.in_(_calendar_types),
-                    )
-                )
-                connector = result.scalars().first()
-                if not connector:
-                    return {
-                        "status": "error",
-                        "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
-                    }
-                actual_connector_id = connector.id
-
-            logger.info(
-                f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
-            )
-
-            if (
-                connector.connector_type
-                == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
-            ):
-                from app.utils.google_credentials import build_composio_credentials
-
-                cca_id = connector.config.get("composio_connected_account_id")
-                if cca_id:
-                    creds = build_composio_credentials(cca_id)
-                else:
-                    return {
-                        "status": "error",
-                        "message": "Composio connected account ID not found for this connector.",
-                    }
-            else:
-                config_data = dict(connector.config)
-
-                from app.config import config as app_config
-                from app.utils.oauth_security import TokenEncryption
-
-                token_encrypted = config_data.get("_token_encrypted", False)
-                if token_encrypted and app_config.SECRET_KEY:
-                    token_encryption = TokenEncryption(app_config.SECRET_KEY)
-                    for key in ("token", "refresh_token", "client_secret"):
-                        if config_data.get(key):
-                            config_data[key] = token_encryption.decrypt_token(
-                                config_data[key]
-                            )
-
-                exp = config_data.get("expiry", "")
-                if exp:
-                    exp = exp.replace("Z", "")
-
-                creds = Credentials(
-                    token=config_data.get("token"),
-                    refresh_token=config_data.get("refresh_token"),
-                    token_uri=config_data.get("token_uri"),
-                    client_id=config_data.get("client_id"),
-                    client_secret=config_data.get("client_secret"),
-                    scopes=config_data.get("scopes", []),
-                    expiry=datetime.fromisoformat(exp) if exp else None,
-                )
-
-            service = await asyncio.get_event_loop().run_in_executor(
-                None, lambda: build("calendar", "v3", credentials=creds)
-            )
+            async with async_session_maker() as db_session:
+                metadata_service = GoogleCalendarToolMetadataService(db_session)
+                context = await metadata_service.get_creation_context(
+                    search_space_id, user_id
+                )
+
+                if "error" in context:
+                    logger.error(
+                        f"Failed to fetch creation context: {context['error']}"
+                    )
+                    return {"status": "error", "message": context["error"]}
+
+                accounts = context.get("accounts", [])
+                if accounts and all(a.get("auth_expired") for a in accounts):
+                    logger.warning(
+                        "All Google Calendar accounts have expired authentication"
+                    )
+                    return {
+                        "status": "auth_error",
+                        "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
+                        "connector_type": "google_calendar",
+                    }
+
+                logger.info(
+                    f"Requesting approval for creating calendar event: summary='{summary}'"
+                )
+                result = request_approval(
+                    action_type="google_calendar_event_creation",
+                    tool_name="create_calendar_event",
+                    params={
+                        "summary": summary,
+                        "start_datetime": start_datetime,
+                        "end_datetime": end_datetime,
+                        "description": description,
+                        "location": location,
+                        "attendees": attendees,
+                        "timezone": context.get("timezone"),
+                        "connector_id": None,
+                    },
+                    context=context,
+                )
+
+                if result.rejected:
+                    return {
+                        "status": "rejected",
+                        "message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
+                    }
-            try:
-                created = await asyncio.get_event_loop().run_in_executor(
-                    None,
-                    lambda: (
-                        service.events()
-                        .insert(calendarId="primary", body=event_body)
-                        .execute()
-                    ),
-                )
-            except Exception as api_err:
-                from googleapiclient.errors import HttpError
-
-                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
-                    logger.warning(
-                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
-                    )
-                    try:
-                        from sqlalchemy.orm.attributes import flag_modified
-
-                        _res = await db_session.execute(
-                            select(SearchSourceConnector).where(
-                                SearchSourceConnector.id == actual_connector_id
-                            )
-                        )
-                        _conn = _res.scalar_one_or_none()
-                        if _conn and not _conn.config.get("auth_expired"):
-                            _conn.config = {**_conn.config, "auth_expired": True}
-                            flag_modified(_conn, "config")
-                            await db_session.commit()
-                    except Exception:
-                        logger.warning(
-                            "Failed to persist auth_expired for connector %s",
-                            actual_connector_id,
-                            exc_info=True,
-                        )
-                    return {
-                        "status": "insufficient_permissions",
-                        "connector_id": actual_connector_id,
-                        "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
-                    }
-                raise
-
-            logger.info(
-                f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
-            )
-
-            kb_message_suffix = ""
-            try:
-                from app.services.google_calendar import GoogleCalendarKBSyncService
-
-                kb_service = GoogleCalendarKBSyncService(db_session)
-                kb_result = await kb_service.sync_after_create(
-                    event_id=created.get("id"),
-                    event_summary=final_summary,
-                    calendar_id="primary",
-                    start_time=final_start_datetime,
-                    end_time=final_end_datetime,
-                    location=final_location,
-                    html_link=created.get("htmlLink"),
-                    description=final_description,
-                    connector_id=actual_connector_id,
-                    search_space_id=search_space_id,
-                    user_id=user_id,
-                )
-                if kb_result["status"] == "success":
-                    kb_message_suffix = " Your knowledge base has also been updated."
-                else:
-                    kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
-            except Exception as kb_err:
-                logger.warning(f"KB sync after create failed: {kb_err}")
-                kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
+
+                final_summary = result.params.get("summary", summary)
+                final_start_datetime = result.params.get(
+                    "start_datetime", start_datetime
+                )
+                final_end_datetime = result.params.get("end_datetime", end_datetime)
+                final_description = result.params.get("description", description)
+                final_location = result.params.get("location", location)
+                final_attendees = result.params.get("attendees", attendees)
+                final_connector_id = result.params.get("connector_id")
+
+                if not final_summary or not final_summary.strip():
+                    return {
+                        "status": "error",
+                        "message": "Event summary cannot be empty.",
+                    }
+
+                from sqlalchemy.future import select
+
+                from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+                _calendar_types = [
+                    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
+                    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
+                ]
+
+                if final_connector_id is not None:
+                    result = await db_session.execute(
+                        select(SearchSourceConnector).filter(
+                            SearchSourceConnector.id == final_connector_id,
+                            SearchSourceConnector.search_space_id == search_space_id,
+                            SearchSourceConnector.user_id == user_id,
+                            SearchSourceConnector.connector_type.in_(_calendar_types),
+                        )
+                    )
+                    connector = result.scalars().first()
+                    if not connector:
+                        return {
+                            "status": "error",
+                            "message": "Selected Google Calendar connector is invalid or has been disconnected.",
+                        }
+                    actual_connector_id = connector.id
+                else:
+                    result = await db_session.execute(
+                        select(SearchSourceConnector).filter(
+                            SearchSourceConnector.search_space_id == search_space_id,
+                            SearchSourceConnector.user_id == user_id,
+                            SearchSourceConnector.connector_type.in_(_calendar_types),
+                        )
+                    )
+                    connector = result.scalars().first()
+                    if not connector:
+                        return {
+                            "status": "error",
+                            "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
+                        }
+                    actual_connector_id = connector.id
+
+                logger.info(
+                    f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
+                )
+
+                is_composio_calendar = (
+                    connector.connector_type
+                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+                )
+                if is_composio_calendar:
+                    cca_id = connector.config.get("composio_connected_account_id")
+                    if not cca_id:
+                        return {
+                            "status": "error",
+                            "message": "Composio connected account ID not found for this connector.",
+                        }
+                else:
+                    config_data = dict(connector.config)
+
+                    from app.config import config as app_config
+                    from app.utils.oauth_security import TokenEncryption
+
+                    token_encrypted = config_data.get("_token_encrypted", False)
+                    if token_encrypted and app_config.SECRET_KEY:
+                        token_encryption = TokenEncryption(app_config.SECRET_KEY)
+                        for key in ("token", "refresh_token", "client_secret"):
+                            if config_data.get(key):
+                                config_data[key] = token_encryption.decrypt_token(
+                                    config_data[key]
+                                )
+
+                    exp = config_data.get("expiry", "")
+                    if exp:
+                        exp = exp.replace("Z", "")
+
+                    creds = Credentials(
+                        token=config_data.get("token"),
+                        refresh_token=config_data.get("refresh_token"),
+                        token_uri=config_data.get("token_uri"),
+                        client_id=config_data.get("client_id"),
+                        client_secret=config_data.get("client_secret"),
+                        scopes=config_data.get("scopes", []),
+                        expiry=datetime.fromisoformat(exp) if exp else None,
+                    )
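+
+                # Google Calendar expects RFC3339 dateTime values; the timeZone
+                # applied below comes from the approval context (workspace
+                # setting) and falls back to UTC.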
"timeZone": tz}, + "end": {"dateTime": final_end_datetime, "timeZone": tz}, + } + if final_description: + event_body["description"] = final_description + if final_location: + event_body["location"] = final_location + if final_attendees: + event_body["attendees"] = [ + {"email": e.strip()} for e in final_attendees if e.strip() + ] + + try: + if is_composio_calendar: + from app.services.composio_service import ComposioService + + composio_params = { + "calendar_id": "primary", + "summary": final_summary, + "start_datetime": final_start_datetime, + "end_datetime": final_end_datetime, + "timezone": tz, + "attendees": final_attendees or [], + } + if final_description: + composio_params["description"] = final_description + if final_location: + composio_params["location"] = final_location + + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_CREATE_EVENT", + params=composio_params, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get( + "error", "Unknown Composio Calendar error" + ) + ) + created = composio_result.get("data", {}) + if isinstance(created, dict): + created = created.get("data", created) + if isinstance(created, dict): + created = created.get("response_data", created) + else: + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + created = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .insert(calendarId="primary", body=event_body) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}" + ) + + kb_message_suffix = "" + try: + from app.services.google_calendar import GoogleCalendarKBSyncService + + kb_service = GoogleCalendarKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + event_id=created.get("id"), + event_summary=final_summary, + calendar_id="primary", + start_time=final_start_datetime, + end_time=final_end_datetime, + location=final_location, + html_link=created.get("htmlLink"), + description=final_description, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." 
+                except Exception as kb_err:
+                    logger.warning(f"KB sync after create failed: {kb_err}")
+                    kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
+
+                return {
+                    "status": "success",
+                    "event_id": created.get("id"),
+                    "html_link": created.get("htmlLink"),
+                    "message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
+                }
         except Exception as e:
             from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
index 4d9d69b4b..e7e891b08 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
 from app.services.google_calendar import GoogleCalendarToolMetadataService
 
 logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_delete_calendar_event_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the delete_calendar_event tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured delete_calendar_event tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def delete_calendar_event(
         event_title_or_id: str,
@@ -54,240 +72,258 @@ def create_delete_calendar_event_tool(
             f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
         )
 
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {
                 "status": "error",
                 "message": "Google Calendar tool not properly configured. Please contact support.",
             }
 
         try:
-            metadata_service = GoogleCalendarToolMetadataService(db_session)
-            context = await metadata_service.get_deletion_context(
-                search_space_id, user_id, event_title_or_id
-            )
-
-            if "error" in context:
-                error_msg = context["error"]
-                if "not found" in error_msg.lower():
-                    logger.warning(f"Event not found: {error_msg}")
-                    return {"status": "not_found", "message": error_msg}
-                logger.error(f"Failed to fetch deletion context: {error_msg}")
-                return {"status": "error", "message": error_msg}
-
-            account = context.get("account", {})
-            if account.get("auth_expired"):
-                logger.warning(
-                    "Google Calendar account %s has expired authentication",
-                    account.get("id"),
-                )
-                return {
-                    "status": "auth_error",
-                    "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
-                    "connector_type": "google_calendar",
-                }
-
-            event = context["event"]
-            event_id = event["event_id"]
-            document_id = event.get("document_id")
-            connector_id_from_context = context["account"]["id"]
-
-            if not event_id:
-                return {
-                    "status": "error",
-                    "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
-                }
-
-            logger.info(
-                f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
-            )
-            result = request_approval(
-                action_type="google_calendar_event_deletion",
-                tool_name="delete_calendar_event",
-                params={
-                    "event_id": event_id,
-                    "connector_id": connector_id_from_context,
-                    "delete_from_kb": delete_from_kb,
-                },
-                context=context,
-            )
-
-            if result.rejected:
-                return {
-                    "status": "rejected",
-                    "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
-                }
+            async with async_session_maker() as db_session:
+                metadata_service = GoogleCalendarToolMetadataService(db_session)
+                context = await metadata_service.get_deletion_context(
+                    search_space_id, user_id, event_title_or_id
+                )
+
+                if "error" in context:
+                    error_msg = context["error"]
+                    if "not found" in error_msg.lower():
+                        logger.warning(f"Event not found: {error_msg}")
+                        return {"status": "not_found", "message": error_msg}
+                    logger.error(f"Failed to fetch deletion context: {error_msg}")
+                    return {"status": "error", "message": error_msg}
+
+                account = context.get("account", {})
+                if account.get("auth_expired"):
+                    logger.warning(
+                        "Google Calendar account %s has expired authentication",
+                        account.get("id"),
+                    )
+                    return {
+                        "status": "auth_error",
+                        "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
+                        "connector_type": "google_calendar",
+                    }
+
+                event = context["event"]
+                event_id = event["event_id"]
+                document_id = event.get("document_id")
+                connector_id_from_context = context["account"]["id"]
-            final_event_id = result.params.get("event_id", event_id)
-            final_connector_id = result.params.get(
-                "connector_id", connector_id_from_context
-            )
-            final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
-
-            if not final_connector_id:
-                return {
-                    "status": "error",
-                    "message": "No connector found for this event.",
-                }
-
-            from sqlalchemy.future import select
-
-            from app.db import SearchSourceConnector, SearchSourceConnectorType
-
-            _calendar_types = [
-                SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
-                SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
-            ]
-
-            result = await db_session.execute(
-                select(SearchSourceConnector).filter(
-                    SearchSourceConnector.id == final_connector_id,
-                    SearchSourceConnector.search_space_id == search_space_id,
-                    SearchSourceConnector.user_id == user_id,
-                    SearchSourceConnector.connector_type.in_(_calendar_types),
-                )
-            )
-            connector = result.scalars().first()
-            if not connector:
-                return {
-                    "status": "error",
-                    "message": "Selected Google Calendar connector is invalid or has been disconnected.",
-                }
-
-            actual_connector_id = connector.id
-
-            logger.info(
-                f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
-            )
-
-            if (
-                connector.connector_type
-                == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
-            ):
-                from app.utils.google_credentials import build_composio_credentials
-
-                cca_id = connector.config.get("composio_connected_account_id")
-                if cca_id:
-                    creds = build_composio_credentials(cca_id)
-                else:
-                    return {
-                        "status": "error",
-                        "message": "Composio connected account ID not found for this connector.",
-                    }
-            else:
-                config_data = dict(connector.config)
-
-                from app.config import config as app_config
-                from app.utils.oauth_security import TokenEncryption
-
-                token_encrypted = config_data.get("_token_encrypted", False)
-                if token_encrypted and app_config.SECRET_KEY:
-                    token_encryption = TokenEncryption(app_config.SECRET_KEY)
-                    for key in ("token", "refresh_token", "client_secret"):
-                        if config_data.get(key):
-                            config_data[key] = token_encryption.decrypt_token(
-                                config_data[key]
-                            )
-
-                exp = config_data.get("expiry", "")
-                if exp:
-                    exp = exp.replace("Z", "")
-
-                creds = Credentials(
-                    token=config_data.get("token"),
-                    refresh_token=config_data.get("refresh_token"),
-                    token_uri=config_data.get("token_uri"),
-                    client_id=config_data.get("client_id"),
-                    client_secret=config_data.get("client_secret"),
-                    scopes=config_data.get("scopes", []),
-                    expiry=datetime.fromisoformat(exp) if exp else None,
-                )
-
-            service = await asyncio.get_event_loop().run_in_executor(
-                None, lambda: build("calendar", "v3", credentials=creds)
-            )
-
-            try:
-                await asyncio.get_event_loop().run_in_executor(
-                    None,
-                    lambda: (
-                        service.events()
-                        .delete(calendarId="primary", eventId=final_event_id)
-                        .execute()
-                    ),
-                )
-            except Exception as api_err:
-                from googleapiclient.errors import HttpError
-
-                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
-                    logger.warning(
-                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
-                    )
-                    try:
-                        from sqlalchemy.orm.attributes import flag_modified
-
-                        _res = await db_session.execute(
-                            select(SearchSourceConnector).where(
-                                SearchSourceConnector.id == actual_connector_id
-                            )
-                        )
-                        _conn = _res.scalar_one_or_none()
-                        if _conn and not _conn.config.get("auth_expired"):
-                            _conn.config = {**_conn.config, "auth_expired": True}
-                            flag_modified(_conn, "config")
-                            await db_session.commit()
-                    except Exception:
-                        logger.warning(
-                            "Failed to persist auth_expired for connector %s",
-                            actual_connector_id,
-                            exc_info=True,
-                        )
-                    return {
-                        "status": "insufficient_permissions",
-                        "connector_id": actual_connector_id,
-                        "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
-                    }
-                raise
+
+                if not event_id:
+                    return {
+                        "status": "error",
+                        "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
+                    }
+
+                logger.info(
+                    f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
+                )
+                result = request_approval(
+                    action_type="google_calendar_event_deletion",
+                    tool_name="delete_calendar_event",
+                    params={
+                        "event_id": event_id,
+                        "connector_id": connector_id_from_context,
+                        "delete_from_kb": delete_from_kb,
+                    },
+                    context=context,
+                )
+
+                if result.rejected:
+                    return {
+                        "status": "rejected",
+                        "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
+                    }
+
+                final_event_id = result.params.get("event_id", event_id)
+                final_connector_id = result.params.get(
+                    "connector_id", connector_id_from_context
+                )
+                final_delete_from_kb = result.params.get(
+                    "delete_from_kb", delete_from_kb
+                )
+
+                if not final_connector_id:
+                    return {
+                        "status": "error",
+                        "message": "No connector found for this event.",
+                    }
+
+                from sqlalchemy.future import select
+
+                from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+                _calendar_types = [
+                    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
+                    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
+                ]
+
+                result = await db_session.execute(
+                    select(SearchSourceConnector).filter(
+                        SearchSourceConnector.id == final_connector_id,
+                        SearchSourceConnector.search_space_id == search_space_id,
+                        SearchSourceConnector.user_id == user_id,
+                        SearchSourceConnector.connector_type.in_(_calendar_types),
+                    )
+                )
+                connector = result.scalars().first()
+                if not connector:
+                    return {
+                        "status": "error",
+                        "message": "Selected Google Calendar connector is invalid or has been disconnected.",
+                    }
+
+                actual_connector_id = connector.id
+
+                logger.info(
+                    f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
+                )
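+
+                # Composio-managed accounts call Google via Composio's tool API;
+                # direct Google accounts build OAuth credentials locally,
+                # decrypting stored tokens when they are marked encrypted.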
+                is_composio_calendar = (
+                    connector.connector_type
+                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+                )
+                if is_composio_calendar:
+                    cca_id = connector.config.get("composio_connected_account_id")
+                    if not cca_id:
+                        return {
+                            "status": "error",
+                            "message": "Composio connected account ID not found for this connector.",
+                        }
+                else:
+                    config_data = dict(connector.config)
+
+                    from app.config import config as app_config
+                    from app.utils.oauth_security import TokenEncryption
+
+                    token_encrypted = config_data.get("_token_encrypted", False)
+                    if token_encrypted and app_config.SECRET_KEY:
+                        token_encryption = TokenEncryption(app_config.SECRET_KEY)
+                        for key in ("token", "refresh_token", "client_secret"):
+                            if config_data.get(key):
+                                config_data[key] = token_encryption.decrypt_token(
+                                    config_data[key]
+                                )
+
+                    exp = config_data.get("expiry", "")
+                    if exp:
+                        exp = exp.replace("Z", "")
+
+                    creds = Credentials(
+                        token=config_data.get("token"),
+                        refresh_token=config_data.get("refresh_token"),
+                        token_uri=config_data.get("token_uri"),
+                        client_id=config_data.get("client_id"),
+                        client_secret=config_data.get("client_secret"),
+                        scopes=config_data.get("scopes", []),
+                        expiry=datetime.fromisoformat(exp) if exp else None,
+                    )
+
+                try:
+                    if is_composio_calendar:
+                        from app.services.composio_service import ComposioService
+
+                        composio_result = await ComposioService().execute_tool(
+                            connected_account_id=cca_id,
+                            tool_name="GOOGLECALENDAR_DELETE_EVENT",
+                            params={
+                                "calendar_id": "primary",
+                                "event_id": final_event_id,
+                            },
+                            entity_id=f"surfsense_{user_id}",
+                        )
+                        if not composio_result.get("success"):
+                            raise RuntimeError(
+                                composio_result.get(
+                                    "error", "Unknown Composio Calendar error"
+                                )
+                            )
+                    else:
+                        service = await asyncio.get_event_loop().run_in_executor(
+                            None, lambda: build("calendar", "v3", credentials=creds)
+                        )
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            lambda: (
+                                service.events()
+                                .delete(calendarId="primary", eventId=final_event_id)
+                                .execute()
+                            ),
+                        )
+                except Exception as api_err:
+                    from googleapiclient.errors import HttpError
+
+                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
+                        logger.warning(
+                            f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
+                        )
+                        try:
+                            from sqlalchemy.orm.attributes import flag_modified
+
+                            _res = await db_session.execute(
+                                select(SearchSourceConnector).where(
+                                    SearchSourceConnector.id == actual_connector_id
+                                )
+                            )
+                            _conn = _res.scalar_one_or_none()
+                            if _conn and not _conn.config.get("auth_expired"):
+                                _conn.config = {**_conn.config, "auth_expired": True}
+                                flag_modified(_conn, "config")
+                                await db_session.commit()
+                        except Exception:
+                            logger.warning(
+                                "Failed to persist auth_expired for connector %s",
+                                actual_connector_id,
+                                exc_info=True,
+                            )
+                        return {
+                            "status": "insufficient_permissions",
+                            "connector_id": actual_connector_id,
+                            "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
+                        }
+                    raise
+
+                logger.info(f"Calendar event deleted: event_id={final_event_id}")
+
+                delete_result: dict[str, Any] = {
+                    "status": "success",
+                    "event_id": final_event_id,
+                    "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
+                }
+
+                deleted_from_kb = False
+                if final_delete_from_kb and document_id:
+                    try:
+                        from app.db import Document
+
+                        doc_result = await db_session.execute(
+                            select(Document).filter(Document.id == document_id)
+                        )
+                        document = doc_result.scalars().first()
+                        if document:
+                            await db_session.delete(document)
+                            await db_session.commit()
+                            deleted_from_kb = True
+                            logger.info(
+                                f"Deleted document {document_id} from knowledge base"
+                            )
+                        else:
+                            logger.warning(f"Document {document_id} not found in KB")
+                    except Exception as e:
+                        logger.error(f"Failed to delete document from KB: {e}")
+                        await db_session.rollback()
+                        delete_result["warning"] = (
+                            f"Event deleted, but failed to remove from knowledge base: {e!s}"
+                        )
+
+                delete_result["deleted_from_kb"] = deleted_from_kb
+                if deleted_from_kb:
+                    delete_result["message"] = (
+                        f"{delete_result.get('message', '')} (also removed from knowledge base)"
+                    )
+
+                return delete_result
         except Exception as e:
             from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
index dc6adb822..e5f18f675 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py
@@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 
 from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
 
 logger = logging.getLogger(__name__)
 
@@ -16,11 +16,57 @@ _CALENDAR_TYPES = [
 ]
 
 
+def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
+    if "T" in value:
+        return value
+    time = "23:59:59" if is_end else "00:00:00"
+    return f"{value}T{time}Z"
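+
+
+# Example: _to_calendar_boundary("2024-05-01", is_end=False) yields
+# "2024-05-01T00:00:00Z", and with is_end=True "2024-05-01T23:59:59Z";
+# values that already contain "T" (full timestamps) pass through unchanged.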
+
+
+def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    events = []
+    for ev in events_raw:
+        start = ev.get("start", {})
+        end = ev.get("end", {})
+        attendees_raw = ev.get("attendees", [])
+        events.append(
+            {
+                "event_id": ev.get("id"),
+                "summary": ev.get("summary", "No Title"),
+                "start": start.get("dateTime") or start.get("date", ""),
+                "end": end.get("dateTime") or end.get("date", ""),
+                "location": ev.get("location", ""),
+                "description": ev.get("description", ""),
+                "html_link": ev.get("htmlLink", ""),
+                "attendees": [a.get("email", "") for a in attendees_raw[:10]],
+                "status": ev.get("status", ""),
+            }
+        )
+    return events
+
+
 def create_search_calendar_events_tool(
     db_session: AsyncSession | None = None,
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the search_calendar_events tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured search_calendar_events tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def search_calendar_events(
         start_date: str,
@@ -38,7 +84,7 @@ def create_search_calendar_events_tool(
             Dictionary with status and a list of events including event_id,
             summary, start, end, location, attendees.
         """
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {
                 "status": "error",
                 "message": "Calendar tool not properly configured.",
@@ -47,76 +93,85 @@ def create_search_calendar_events_tool(
         max_results = min(max_results, 50)
 
         try:
-            result = await db_session.execute(
-                select(SearchSourceConnector).filter(
-                    SearchSourceConnector.search_space_id == search_space_id,
-                    SearchSourceConnector.user_id == user_id,
-                    SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
-                )
-            )
-            connector = result.scalars().first()
-            if not connector:
-                return {
-                    "status": "error",
-                    "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
-                }
-
-            creds = _build_credentials(connector)
-
-            from app.connectors.google_calendar_connector import GoogleCalendarConnector
-
-            cal = GoogleCalendarConnector(
-                credentials=creds,
-                session=db_session,
-                user_id=user_id,
-                connector_id=connector.id,
-            )
-
-            events_raw, error = await cal.get_all_primary_calendar_events(
-                start_date=start_date,
-                end_date=end_date,
-                max_results=max_results,
-            )
-
-            if error:
-                if (
-                    "re-authenticate" in error.lower()
-                    or "authentication failed" in error.lower()
-                ):
-                    return {
-                        "status": "auth_error",
-                        "message": error,
-                        "connector_type": "google_calendar",
-                    }
-                if "no events found" in error.lower():
-                    return {
-                        "status": "success",
-                        "events": [],
-                        "total": 0,
-                        "message": error,
-                    }
-                return {"status": "error", "message": error}
-
-            events = []
-            for ev in events_raw:
-                start = ev.get("start", {})
-                end = ev.get("end", {})
-                attendees_raw = ev.get("attendees", [])
-                events.append(
-                    {
-                        "event_id": ev.get("id"),
-                        "summary": ev.get("summary", "No Title"),
-                        "start": start.get("dateTime") or start.get("date", ""),
-                        "end": end.get("dateTime") or end.get("date", ""),
-                        "location": ev.get("location", ""),
-                        "description": ev.get("description", ""),
-                        "html_link": ev.get("htmlLink", ""),
-                        "attendees": [a.get("email", "") for a in attendees_raw[:10]],
-                        "status": ev.get("status", ""),
-                    }
-                )
-
-            return {"status": "success", "events": events, "total": len(events)}
+            async with async_session_maker() as db_session:
+                result = await db_session.execute(
+                    select(SearchSourceConnector).filter(
+                        SearchSourceConnector.search_space_id == search_space_id,
+                        SearchSourceConnector.user_id == user_id,
+                        SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
+                    )
+                )
+                connector = result.scalars().first()
+                if not connector:
+                    return {
+                        "status": "error",
+                        "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
+                    }
+
+                if (
+                    connector.connector_type
+                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+                ):
+                    cca_id = connector.config.get("composio_connected_account_id")
+                    if not cca_id:
+                        return {
+                            "status": "error",
+                            "message": "Composio connected account ID not found for this connector.",
+                        }
+
+                    from app.services.composio_service import ComposioService
+
+                    events_raw, error = await ComposioService().get_calendar_events(
+                        connected_account_id=cca_id,
+                        entity_id=f"surfsense_{user_id}",
+                        time_min=_to_calendar_boundary(start_date, is_end=False),
+                        time_max=_to_calendar_boundary(end_date, is_end=True),
+                        max_results=max_results,
+                    )
+                    if not events_raw and not error:
+                        error = "No events found in the specified date range."
+                else:
+                    creds = _build_credentials(connector)
+
+                    from app.connectors.google_calendar_connector import (
+                        GoogleCalendarConnector,
+                    )
+
+                    cal = GoogleCalendarConnector(
+                        credentials=creds,
+                        session=db_session,
+                        user_id=user_id,
+                        connector_id=connector.id,
+                    )
+
+                    events_raw, error = await cal.get_all_primary_calendar_events(
+                        start_date=start_date,
+                        end_date=end_date,
+                        max_results=max_results,
+                    )
+
+                if error:
+                    if (
+                        "re-authenticate" in error.lower()
+                        or "authentication failed" in error.lower()
+                    ):
+                        return {
+                            "status": "auth_error",
+                            "message": error,
+                            "connector_type": "google_calendar",
+                        }
+                    if "no events found" in error.lower():
+                        return {
+                            "status": "success",
+                            "events": [],
+                            "total": 0,
+                            "message": error,
+                        }
+                    return {"status": "error", "message": error}
+
+                events = _format_calendar_events(events_raw)
+
+                return {"status": "success", "events": events, "total": len(events)}
         except Exception as e:
             from langgraph.errors import GraphInterrupt
diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
index 259f52bba..b8561fee6 100644
--- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
+++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
 from app.services.google_calendar import GoogleCalendarToolMetadataService
 
 logger = logging.getLogger(__name__)
@@ -33,6 +34,23 @@ def create_update_calendar_event_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the update_calendar_event tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured update_calendar_event tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def update_calendar_event(
         event_title_or_id: str,
@@ -74,272 +92,317 @@ def create_update_calendar_event_tool(
         """
         logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
 
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {
                 "status": "error",
                 "message": "Google Calendar tool not properly configured. Please contact support.",
             }
 
         try:
-            metadata_service = GoogleCalendarToolMetadataService(db_session)
-            context = await metadata_service.get_update_context(
-                search_space_id, user_id, event_title_or_id
-            )
-
-            if "error" in context:
-                error_msg = context["error"]
-                if "not found" in error_msg.lower():
-                    logger.warning(f"Event not found: {error_msg}")
-                    return {"status": "not_found", "message": error_msg}
-                logger.error(f"Failed to fetch update context: {error_msg}")
-                return {"status": "error", "message": error_msg}
-
-            if context.get("auth_expired"):
-                logger.warning("Google Calendar account has expired authentication")
-                return {
-                    "status": "auth_error",
-                    "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
-                    "connector_type": "google_calendar",
-                }
-
-            event = context["event"]
-            event_id = event["event_id"]
-            document_id = event.get("document_id")
-            connector_id_from_context = context["account"]["id"]
-
-            if not event_id:
-                return {
-                    "status": "error",
-                    "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
-                }
-
-            logger.info(
-                f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
-            )
-            result = request_approval(
-                action_type="google_calendar_event_update",
-                tool_name="update_calendar_event",
-                params={
-                    "event_id": event_id,
-                    "document_id": document_id,
-                    "connector_id": connector_id_from_context,
-                    "new_summary": new_summary,
-                    "new_start_datetime": new_start_datetime,
-                    "new_end_datetime": new_end_datetime,
-                    "new_description": new_description,
-                    "new_location": new_location,
-                    "new_attendees": new_attendees,
-                },
-                context=context,
-            )
-
-            if result.rejected:
-                return {
-                    "status": "rejected",
-                    "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
-                }
-
-            final_event_id = result.params.get("event_id", event_id)
-            final_connector_id = result.params.get(
-                "connector_id", connector_id_from_context
-            )
-            final_new_summary = result.params.get("new_summary", new_summary)
-            final_new_start_datetime = result.params.get(
-                "new_start_datetime", new_start_datetime
-            )
-            final_new_end_datetime = result.params.get(
-                "new_end_datetime", new_end_datetime
-            )
-            final_new_description = result.params.get(
-                "new_description", new_description
-            )
-            final_new_location = result.params.get("new_location", new_location)
-            final_new_attendees = result.params.get("new_attendees", new_attendees)
-
-            if not final_connector_id:
-                return {
-                    "status": "error",
-                    "message": "No connector found for this event.",
-                }
-
-            from sqlalchemy.future import select
-
-            from app.db import SearchSourceConnector, SearchSourceConnectorType
-
-            _calendar_types = [
-                SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
-                SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
-            ]
-
-            result = await db_session.execute(
-                select(SearchSourceConnector).filter(
-                    SearchSourceConnector.id == final_connector_id,
-                    SearchSourceConnector.search_space_id == search_space_id,
-                    SearchSourceConnector.user_id == user_id,
-                    SearchSourceConnector.connector_type.in_(_calendar_types),
-                )
-            )
-            connector = result.scalars().first()
-            if not connector:
-                return {
-                    "status": "error",
-                    "message": "Selected Google Calendar connector is invalid or has been disconnected.",
-                }
-
-            actual_connector_id = connector.id
-
-            logger.info(
-                f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
-            )
-
-            if (
-                connector.connector_type
-                == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
-            ):
-                from app.utils.google_credentials import build_composio_credentials
-
-                cca_id = connector.config.get("composio_connected_account_id")
-                if cca_id:
-                    creds = build_composio_credentials(cca_id)
-                else:
-                    return {
-                        "status": "error",
-                        "message": "Composio connected account ID not found for this connector.",
-                    }
+            async with async_session_maker() as db_session:
+                metadata_service = GoogleCalendarToolMetadataService(db_session)
+                context = await metadata_service.get_update_context(
+                    search_space_id, user_id, event_title_or_id
+                )
+
+                if "error" in context:
+                    error_msg = context["error"]
+                    if "not found" in error_msg.lower():
+                        logger.warning(f"Event not found: {error_msg}")
+                        return {"status": "not_found", "message": error_msg}
+                    logger.error(f"Failed to fetch update context: {error_msg}")
+                    return {"status": "error", "message": error_msg}
+
+                if context.get("auth_expired"):
+                    logger.warning("Google Calendar account has expired authentication")
+                    return {
+                        "status": "auth_error",
+                        "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
+                        "connector_type": "google_calendar",
+                    }
+
+                event = context["event"]
+                event_id = event["event_id"]
+                document_id = event.get("document_id")
+                connector_id_from_context = context["account"]["id"]
+
+                if not event_id:
+                    return {
+                        "status": "error",
+                        "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
+                    }
Please re-index the event and try again.", } - else: - config_data = dict(connector.config) - from app.config import config as app_config - from app.utils.oauth_security import TokenEncryption - - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and app_config.SECRET_KEY: - token_encryption = TokenEncryption(app_config.SECRET_KEY) - for key in ("token", "refresh_token", "client_secret"): - if config_data.get(key): - config_data[key] = token_encryption.decrypt_token( - config_data[key] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + logger.info( + f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})" + ) + result = request_approval( + action_type="google_calendar_event_update", + tool_name="update_calendar_event", + params={ + "event_id": event_id, + "document_id": document_id, + "connector_id": connector_id_from_context, + "new_summary": new_summary, + "new_start_datetime": new_start_datetime, + "new_end_datetime": new_end_datetime, + "new_description": new_description, + "new_location": new_location, + "new_attendees": new_attendees, + }, + context=context, ) - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.", + } - update_body: dict[str, Any] = {} - if final_new_summary is not None: - update_body["summary"] = final_new_summary - if final_new_start_datetime is not None: - update_body["start"] = _build_time_body( - final_new_start_datetime, context + final_event_id = result.params.get("event_id", event_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context ) - if final_new_end_datetime is not None: - update_body["end"] = _build_time_body(final_new_end_datetime, context) - if final_new_description is not None: - update_body["description"] = final_new_description - if final_new_location is not None: - update_body["location"] = final_new_location - if final_new_attendees is not None: - update_body["attendees"] = [ - {"email": e.strip()} for e in final_new_attendees if e.strip() + final_new_summary = result.params.get("new_summary", new_summary) + final_new_start_datetime = result.params.get( + "new_start_datetime", new_start_datetime + ) + final_new_end_datetime = result.params.get( + "new_end_datetime", new_end_datetime + ) + final_new_description = result.params.get( + "new_description", new_description + ) + final_new_location = result.params.get("new_location", new_location) + final_new_attendees = result.params.get("new_attendees", new_attendees) + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this event.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _calendar_types = [ + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, ] - if not update_body: - return { - "status": "error", - "message": "No 
changes specified. Please provide at least one field to update.", - } - - try: - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .patch( - calendarId="primary", - eventId=final_event_id, - body=update_body, - ) - .execute() - ), + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_calendar_types), + ) ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) + connector = result.scalars().first() + if not connector: return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + "status": "error", + "message": "Selected Google Calendar connector is invalid or has been disconnected.", } - raise - logger.info(f"Calendar event updated: event_id={final_event_id}") + actual_connector_id = connector.id - kb_message_suffix = "" - if document_id is not None: - try: - from app.services.google_calendar import GoogleCalendarKBSyncService + logger.info( + f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}" + ) - kb_service = GoogleCalendarKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=document_id, - event_id=final_event_id, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, + is_composio_calendar = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ) + if is_composio_calendar: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this connector.", + } + else: + config_data = dict(connector.config) + + from app.config import config as app_config + from app.utils.oauth_security import TokenEncryption + + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and app_config.SECRET_KEY: + token_encryption = TokenEncryption(app_config.SECRET_KEY) + for key in ("token", "refresh_token", "client_secret"): + if config_data.get(key): + config_data[key] = token_encryption.decrypt_token( + config_data[key] + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + 
expiry=datetime.fromisoformat(exp) if exp else None, ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after update failed: {kb_err}") - kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." - return { - "status": "success", - "event_id": final_event_id, - "html_link": updated.get("htmlLink"), - "message": f"Successfully updated the calendar event.{kb_message_suffix}", - } + update_body: dict[str, Any] = {} + if final_new_summary is not None: + update_body["summary"] = final_new_summary + if final_new_start_datetime is not None: + update_body["start"] = _build_time_body( + final_new_start_datetime, context + ) + if final_new_end_datetime is not None: + update_body["end"] = _build_time_body( + final_new_end_datetime, context + ) + if final_new_description is not None: + update_body["description"] = final_new_description + if final_new_location is not None: + update_body["location"] = final_new_location + if final_new_attendees is not None: + update_body["attendees"] = [ + {"email": e.strip()} for e in final_new_attendees if e.strip() + ] + + if not update_body: + return { + "status": "error", + "message": "No changes specified. Please provide at least one field to update.", + } + + try: + if is_composio_calendar: + from app.services.composio_service import ComposioService + + composio_params: dict[str, Any] = { + "calendar_id": "primary", + "event_id": final_event_id, + } + if final_new_summary is not None: + composio_params["summary"] = final_new_summary + if final_new_start_datetime is not None: + composio_params["start_time"] = final_new_start_datetime + if final_new_end_datetime is not None: + composio_params["end_time"] = final_new_end_datetime + if final_new_description is not None: + composio_params["description"] = final_new_description + if final_new_location is not None: + composio_params["location"] = final_new_location + if final_new_attendees is not None: + composio_params["attendees"] = [ + e.strip() for e in final_new_attendees if e.strip() + ] + if not _is_date_only( + final_new_start_datetime or final_new_end_datetime or "" + ): + composio_params["timezone"] = context.get("timezone", "UTC") + + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_PATCH_EVENT", + params=composio_params, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get( + "error", "Unknown Composio Calendar error" + ) + ) + updated = composio_result.get("data", {}) + if isinstance(updated, dict): + updated = updated.get("data", updated) + if isinstance(updated, dict): + updated = updated.get("response_data", updated) + else: + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + updated = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .patch( + calendarId="primary", + eventId=final_event_id, + body=update_body, + ) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + ) + try: + from 
sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info(f"Calendar event updated: event_id={final_event_id}") + + kb_message_suffix = "" + if document_id is not None: + try: + from app.services.google_calendar import ( + GoogleCalendarKBSyncService, + ) + + kb_service = GoogleCalendarKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=document_id, + event_id=final_event_id, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after update failed: {kb_err}") + kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." + + return { + "status": "success", + "event_id": final_event_id, + "html_link": updated.get("htmlLink"), + "message": f"Successfully updated the calendar event.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py index f36db8f3f..66199ca67 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py @@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.google_drive.client import GoogleDriveClient from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET +from app.db import async_session_maker from app.services.google_drive import GoogleDriveToolMetadataService logger = logging.getLogger(__name__) @@ -23,6 +24,25 @@ def create_create_google_drive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_google_drive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
+ search_space_id: Search space ID to find the Google Drive connector + user_id: User ID for fetching user-specific context + + Returns: + Configured create_google_drive_file tool + """ + del db_session # per-call session — see docstring + @tool async def create_google_drive_file( name: str, @@ -65,7 +85,7 @@ def create_create_google_drive_file_tool( f"create_google_drive_file called: name='{name}', type='{file_type}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Drive tool not properly configured. Please contact support.", @@ -78,195 +98,232 @@ def create_create_google_drive_file_tool( } try: - metadata_service = GoogleDriveToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) + async with async_session_maker() as db_session: + metadata_service = GoogleDriveToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id + ) - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Google Drive accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "google_drive", - } - - logger.info( - f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'" - ) - result = request_approval( - action_type="google_drive_file_creation", - tool_name="create_google_drive_file", - params={ - "name": name, - "file_type": file_type, - "content": content, - "connector_id": None, - "parent_folder_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The file was not created. 
Do not ask again or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_file_type = result.params.get("file_type", file_type) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_id = result.params.get("parent_folder_id") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - mime_type = _MIME_MAP.get(final_file_type) - if not mime_type: - return { - "status": "error", - "message": f"Unsupported file type '{final_file_type}'.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _drive_types = [ - SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Drive connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Google Drive connector found. 
Please connect Google Drive in your workspace settings.", - } - actual_connector_id = connector.id + return {"status": "error", "message": context["error"]} - logger.info( - f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" - ) - - pre_built_creds = None - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) - - client = GoogleDriveClient( - session=db_session, - connector_id=actual_connector_id, - credentials=pre_built_creds, - ) - try: - created = await client.create_file( - name=final_name, - mime_type=mime_type, - parent_folder_id=final_parent_folder_id, - content=final_content, - ) - except HttpError as http_err: - if http_err.resp.status == 403: + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {http_err}" + "All Google Drive accounts have expired authentication" ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + "status": "auth_error", + "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "google_drive", } - raise - logger.info( - f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" - ) - - kb_message_suffix = "" - try: - from app.services.google_drive import GoogleDriveKBSyncService - - kb_service = GoogleDriveKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=created.get("id"), - file_name=created.get("name", final_name), - mime_type=mime_type, - web_view_link=created.get("webViewLink"), - content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, + logger.info( + f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'" + ) + result = request_approval( + action_type="google_drive_file_creation", + tool_name="create_google_drive_file", + params={ + "name": name, + "file_type": file_type, + "content": content, + "connector_id": None, + "parent_folder_id": None, + }, + context=context, ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." 
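(Reviewer note: the session-handling change above repeats one shape across every tool in this diff: the factory drops its injected session and the tool body opens a short-lived one per call via async_session_maker. A minimal sketch of that pattern with a hypothetical tool name; the @tool import path is an assumption, since the diffs only show the decorator in use:

    from langchain_core.tools import tool  # assumed import path

    from app.db import async_session_maker


    def create_example_tool(db_session=None, search_space_id=None, user_id=None):
        del db_session  # per-call session, opened inside the tool body

        @tool
        async def example_tool(arg: str) -> dict:
            """Illustrative body: all DB work lives inside one short session."""
            if search_space_id is None or user_id is None:
                return {"status": "error", "message": "Tool not properly configured."}
            async with async_session_maker() as session:
                # metadata lookup, HITL approval, external API call, KB sync
                return {"status": "success"}

        return example_tool

Because nothing request-scoped is captured in the closure, the compiled-agent cache can reuse it across HTTP requests without surfacing stale or closed sessions.)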
- return { - "status": "success", - "file_id": created.get("id"), - "name": created.get("name"), - "web_view_link": created.get("webViewLink"), - "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The file was not created. Do not ask again or suggest alternatives.", + } + + final_name = result.params.get("name", name) + final_file_type = result.params.get("file_type", file_type) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_id = result.params.get("parent_folder_id") + + if not final_name or not final_name.strip(): + return {"status": "error", "message": "File name cannot be empty."} + + mime_type = _MIME_MAP.get(final_file_type) + if not mime_type: + return { + "status": "error", + "message": f"Unsupported file type '{final_file_type}'.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _drive_types = [ + SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, + ] + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_drive_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Google Drive connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id + else: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_drive_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Google Drive connector found. 
Please connect Google Drive in your workspace settings.", + } + actual_connector_id = connector.id + + logger.info( + f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" + ) + + is_composio_drive = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR + ) + if is_composio_drive: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Drive connector.", + } + client = GoogleDriveClient( + session=db_session, + connector_id=actual_connector_id, + ) + try: + if is_composio_drive: + from app.services.composio_service import ComposioService + + params: dict[str, Any] = { + "name": final_name, + "mimeType": mime_type, + "fields": "id,name,webViewLink,mimeType", + } + if final_parent_folder_id: + params["parents"] = [final_parent_folder_id] + if final_content: + params["description"] = final_content[:4096] + + result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLEDRIVE_CREATE_FILE", + params=params, + entity_id=f"surfsense_{user_id}", + ) + if not result.get("success"): + raise RuntimeError( + result.get("error", "Unknown Composio Drive error") + ) + created = result.get("data", {}) + if isinstance(created, dict): + created = created.get("data", created) + if isinstance(created, dict): + created = created.get("response_data", created) + if not isinstance(created, dict): + created = {} + else: + created = await client.create_file( + name=final_name, + mime_type=mime_type, + parent_folder_id=final_parent_folder_id, + content=final_content, + ) + except HttpError as http_err: + if http_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {http_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" + ) + + kb_message_suffix = "" + try: + from app.services.google_drive import GoogleDriveKBSyncService + + kb_service = GoogleDriveKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + file_id=created.get("id"), + file_name=created.get("name", final_name), + mime_type=mime_type, + web_view_link=created.get("webViewLink"), + content=final_content, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." 
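(Reviewer note: the nested data / data / response_data unwrapping above appears in near-identical form in both the Calendar and the Drive Composio branches; a small helper could centralize it. A sketch, with a hypothetical helper name and assuming payloads nest at most the two envelopes the checks above handle:

    from typing import Any


    def unwrap_composio_payload(result: dict[str, Any]) -> dict[str, Any]:
        """Peel the data -> data -> response_data envelopes defensively."""
        payload: Any = result.get("data", {})
        for key in ("data", "response_data"):
            if isinstance(payload, dict):
                payload = payload.get(key, payload)
        return payload if isinstance(payload, dict) else {})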
+ except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "file_id": created.get("id"), + "name": created.get("name"), + "web_view_link": created.get("webViewLink"), + "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py index 832afff0d..b3c9240d8 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py @@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.google_drive.client import GoogleDriveClient +from app.db import async_session_maker from app.services.google_drive import GoogleDriveToolMetadataService logger = logging.getLogger(__name__) @@ -17,6 +18,25 @@ def create_delete_google_drive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the delete_google_drive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Google Drive connector + user_id: User ID for fetching user-specific context + + Returns: + Configured delete_google_drive_file tool + """ + del db_session # per-call session — see docstring + @tool async def delete_google_drive_file( file_name: str, @@ -55,197 +75,214 @@ def create_delete_google_drive_file_tool( f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Drive tool not properly configured. Please contact support.", } try: - metadata_service = GoogleDriveToolMetadataService(db_session) - context = await metadata_service.get_trash_context( - search_space_id, user_id, file_name - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"File not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch trash context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Google Drive account %s has expired authentication", - account.get("id"), + async with async_session_maker() as db_session: + metadata_service = GoogleDriveToolMetadataService(db_session) + context = await metadata_service.get_trash_context( + search_space_id, user_id, file_name ) - return { - "status": "auth_error", - "message": "The Google Drive account for this file needs re-authentication. 
Please re-authenticate in your connector settings.", - "connector_type": "google_drive", - } - file = context["file"] - file_id = file["file_id"] - document_id = file.get("document_id") - connector_id_from_context = context["account"]["id"] + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"File not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch trash context: {error_msg}") + return {"status": "error", "message": error_msg} - if not file_id: - return { - "status": "error", - "message": "File ID is missing from the indexed document. Please re-index the file and try again.", - } - - logger.info( - f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="google_drive_file_trash", - tool_name="delete_google_drive_file", - params={ - "file_id": file_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.", - } - - final_file_id = result.params.get("file_id", file_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this file.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _drive_types = [ - SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Drive connector is invalid or has been disconnected.", - } - - logger.info( - f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" - ) - - pre_built_creds = None - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) - - client = GoogleDriveClient( - session=db_session, - connector_id=connector.id, - credentials=pre_built_creds, - ) - try: - await client.trash_file(file_id=final_file_id) - except HttpError as http_err: - if http_err.resp.status == 403: + account = context.get("account", {}) + if account.get("auth_expired"): logger.warning( - f"Insufficient permissions for connector {connector.id}: {http_err}" + "Google Drive account %s has expired authentication", + account.get("id"), ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - 
"Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + "status": "auth_error", + "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "google_drive", } - raise - logger.info( - f"Google Drive file deleted (moved to trash): file_id={final_file_id}" - ) + file = context["file"] + file_id = file["file_id"] + document_id = file.get("document_id") + connector_id_from_context = context["account"]["id"] - trash_result: dict[str, Any] = { - "status": "success", - "file_id": final_file_id, - "message": f"Successfully moved '{file['name']}' to trash.", - } + if not file_id: + return { + "status": "error", + "message": "File ID is missing from the indexed document. Please re-index the file and try again.", + } - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"File moved to trash, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + logger.info( + f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})" + ) + result = request_approval( + action_type="google_drive_file_trash", + tool_name="delete_google_drive_file", + params={ + "file_id": file_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - return trash_result + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The file was not trashed. 
Do not ask again or suggest alternatives.", + } + + final_file_id = result.params.get("file_id", file_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this file.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _drive_types = [ + SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, + ] + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_drive_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Google Drive connector is invalid or has been disconnected.", + } + + logger.info( + f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" + ) + + is_composio_drive = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR + ) + if is_composio_drive: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Drive connector.", + } + + client = GoogleDriveClient( + session=db_session, + connector_id=connector.id, + ) + try: + if is_composio_drive: + from app.services.composio_service import ComposioService + + result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLEDRIVE_TRASH_FILE", + params={"file_id": final_file_id}, + entity_id=f"surfsense_{user_id}", + ) + if not result.get("success"): + raise RuntimeError( + result.get("error", "Unknown Composio Drive error") + ) + else: + await client.trash_file(file_id=final_file_id) + except HttpError as http_err: + if http_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {connector.id}: {http_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + if not connector.config.get("auth_expired"): + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + connector.id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Google Drive account needs additional permissions. 
Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Google Drive file deleted (moved to trash): file_id={final_file_id}" + ) + + trash_result: dict[str, Any] = { + "status": "success", + "file_id": final_file_id, + "message": f"Successfully moved '{file['name']}' to trash.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"File moved to trash, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py index 92248c2c9..5b64929de 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset( { "create_gmail_draft", "update_gmail_draft", + "create_calendar_event", "create_notion_page", "create_confluence_page", "create_google_drive_file", diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py index 8b40dde65..0b04f1642 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py @@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector +from app.db import async_session_maker from app.services.jira import JiraToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,28 @@ def create_create_jira_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): + """Factory function to create the create_jira_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Per-call sessions also + keep the request's outer transaction free of long-running Jira API + blocking. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
+ search_space_id: Search space ID to find the Jira connector + user_id: User ID for fetching user-specific context + connector_id: Optional specific connector ID (if known) + + Returns: + Configured create_jira_issue tool + """ + del db_session # per-call session — see docstring + @tool async def create_jira_issue( project_key: str, @@ -49,158 +72,167 @@ def create_create_jira_issue_tool( f"create_jira_issue called: project_key='{project_key}', summary='{summary}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Jira tool not properly configured."} try: - metadata_service = JiraToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected Jira accounts need re-authentication.", - "connector_type": "jira", - } - - result = request_approval( - action_type="jira_issue_creation", - tool_name="create_jira_issue", - params={ - "project_key": project_key, - "summary": summary, - "issue_type": issue_type, - "description": description, - "priority": priority, - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_project_key = result.params.get("project_key", project_key) - final_summary = result.params.get("summary", summary) - final_issue_type = result.params.get("issue_type", issue_type) - final_description = result.params.get("description", description) - final_priority = result.params.get("priority", priority) - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_summary or not final_summary.strip(): - return {"status": "error", "message": "Issue summary cannot be empty."} - if not final_project_key: - return {"status": "error", "message": "A project must be selected."} - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = JiraToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: - return {"status": "error", "message": "No Jira connector found."} - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, - ) + + if "error" in context: + return {"status": "error", "message": context["error"]} + + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + return 
{ + "status": "auth_error", + "message": "All connected Jira accounts need re-authentication.", + "connector_type": "jira", + } + + result = request_approval( + action_type="jira_issue_creation", + tool_name="create_jira_issue", + params={ + "project_key": project_key, + "summary": summary, + "issue_type": issue_type, + "description": description, + "priority": priority, + "connector_id": connector_id, + }, + context=context, ) - connector = result.scalars().first() - if not connector: + + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_project_key = result.params.get("project_key", project_key) + final_summary = result.params.get("summary", summary) + final_issue_type = result.params.get("issue_type", issue_type) + final_description = result.params.get("description", description) + final_priority = result.params.get("priority", priority) + final_connector_id = result.params.get("connector_id", connector_id) + + if not final_summary or not final_summary.strip(): return { "status": "error", - "message": "Selected Jira connector is invalid.", + "message": "Issue summary cannot be empty.", } + if not final_project_key: + return {"status": "error", "message": "A project must be selected."} - try: - jira_history = JiraHistoryConnector( - session=db_session, connector_id=actual_connector_id - ) - jira_client = await jira_history._get_jira_client() - api_result = await asyncio.to_thread( - jira_client.create_issue, - project_key=final_project_key, - summary=final_summary, - issue_type=final_issue_type, - description=final_description, - priority=final_priority, - ) - except Exception as api_err: - if "status code 403" in str(api_err).lower(): - try: - _conn = connector - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - pass - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + from sqlalchemy.future import select - issue_key = api_result.get("key", "") - issue_url = ( - f"{jira_history._base_url}/browse/{issue_key}" - if jira_history._base_url and issue_key - else "" - ) + from app.db import SearchSourceConnector, SearchSourceConnectorType - kb_message_suffix = "" - try: - from app.services.jira import JiraKBSyncService - - kb_service = JiraKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - issue_id=issue_key, - issue_identifier=issue_key, - issue_title=final_summary, - description=final_description, - state="To Do", - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Jira connector found.", + } + actual_connector_id = connector.id else: - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." 
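(Reviewer note: every 403 handler in these tools persists the same auth_expired marker on the connector's JSON config before returning insufficient_permissions, and the block is duplicated verbatim per tool. A shared best-effort helper could look like this; the name is hypothetical, but the body is lifted from the handlers in this diff:

    import logging

    from sqlalchemy.ext.asyncio import AsyncSession
    from sqlalchemy.future import select
    from sqlalchemy.orm.attributes import flag_modified

    from app.db import SearchSourceConnector

    logger = logging.getLogger(__name__)


    async def mark_connector_auth_expired(
        session: AsyncSession, connector_id: int
    ) -> None:
        """Best-effort: flag the connector so the UI prompts for re-auth."""
        try:
            res = await session.execute(
                select(SearchSourceConnector).where(
                    SearchSourceConnector.id == connector_id
                )
            )
            conn = res.scalar_one_or_none()
            if conn and not conn.config.get("auth_expired"):
                conn.config = {**conn.config, "auth_expired": True}
                flag_modified(conn, "config")  # JSON column needs explicit dirtying
                await session.commit()
        except Exception:
            logger.warning(
                "Failed to persist auth_expired for connector %s",
                connector_id,
                exc_info=True,
            ))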
- except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Jira connector is invalid.", + } - return { - "status": "success", - "issue_key": issue_key, - "issue_url": issue_url, - "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}", - } + try: + jira_history = JiraHistoryConnector( + session=db_session, connector_id=actual_connector_id + ) + jira_client = await jira_history._get_jira_client() + api_result = await asyncio.to_thread( + jira_client.create_issue, + project_key=final_project_key, + summary=final_summary, + issue_type=final_issue_type, + description=final_description, + priority=final_priority, + ) + except Exception as api_err: + if "status code 403" in str(api_err).lower(): + try: + _conn = connector + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + issue_key = api_result.get("key", "") + issue_url = ( + f"{jira_history._base_url}/browse/{issue_key}" + if jira_history._base_url and issue_key + else "" + ) + + kb_message_suffix = "" + try: + from app.services.jira import JiraKBSyncService + + kb_service = JiraKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + issue_id=issue_key, + issue_identifier=issue_key, + issue_title=final_summary, + description=final_description, + state="To Do", + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." 
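(Reviewer note: the Jira SDK client used here is synchronous, so both the create and delete paths run the API round-trip in a worker thread rather than on the event loop. The pattern in isolation; the client's create_issue signature is taken from the call above and should be treated as an assumption about JiraHistoryConnector:

    import asyncio


    async def create_issue_without_blocking(jira_client, **fields) -> dict:
        # asyncio.to_thread runs the blocking SDK call in the default
        # thread-pool executor, keeping the FastAPI event loop responsive.
        return await asyncio.to_thread(jira_client.create_issue, **fields))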
+ + return { + "status": "success", + "issue_key": issue_key, + "issue_url": issue_url, + "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py index 6466c80ea..c41aedad9 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py @@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector +from app.db import async_session_maker from app.services.jira import JiraToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,26 @@ def create_delete_jira_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): + """Factory function to create the delete_jira_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Jira connector + user_id: User ID for fetching user-specific context + connector_id: Optional specific connector ID (if known) + + Returns: + Configured delete_jira_issue tool + """ + del db_session # per-call session — see docstring + @tool async def delete_jira_issue( issue_title_or_key: str, @@ -44,130 +65,136 @@ def create_delete_jira_issue_tool( f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Jira tool not properly configured."} try: - metadata_service = JiraToolMetadataService(db_session) - context = await metadata_service.get_deletion_context( - search_space_id, user_id, issue_title_or_key - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "jira", - } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - - issue_data = context["issue"] - issue_key = issue_data["issue_id"] - document_id = issue_data["document_id"] - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="jira_issue_deletion", - tool_name="delete_jira_issue", - params={ - "issue_key": issue_key, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_issue_key = result.params.get("issue_key", issue_key) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this issue.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = JiraToolMetadataService(db_session) + context = await metadata_service.get_deletion_context( + search_space_id, user_id, issue_title_or_key ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Jira connector is invalid.", - } - try: - jira_history = JiraHistoryConnector( - session=db_session, connector_id=final_connector_id + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "jira", + } + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} + + issue_data = context["issue"] + issue_key = issue_data["issue_id"] + document_id = issue_data["document_id"] + connector_id_from_context = context.get("account", {}).get("id") + + result = request_approval( + action_type="jira_issue_deletion", + tool_name="delete_jira_issue", + params={ + "issue_key": issue_key, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - jira_client = await jira_history._get_jira_client() - await asyncio.to_thread(jira_client.delete_issue, final_issue_key) - except Exception as api_err: - if "status code 403" in str(api_err).lower(): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", } - raise - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document + final_issue_key = result.params.get("issue_key", issue_key) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this issue.", + } + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Jira connector is invalid.", + } - message = f"Jira issue {final_issue_key} deleted successfully." - if deleted_from_kb: - message += " Also removed from the knowledge base." + try: + jira_history = JiraHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + jira_client = await jira_history._get_jira_client() + await asyncio.to_thread(jira_client.delete_issue, final_issue_key) + except Exception as api_err: + if "status code 403" in str(api_err).lower(): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + } + raise - return { - "status": "success", - "issue_key": final_issue_key, - "deleted_from_kb": deleted_from_kb, - "message": message, - } + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + + message = f"Jira issue {final_issue_key} deleted successfully." + if deleted_from_kb: + message += " Also removed from the knowledge base." 
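+                    # Returning from inside the ``async with`` block is fine: the
+                    # context manager still closes the per-call session on exit.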
+ + return { + "status": "success", + "issue_key": final_issue_key, + "deleted_from_kb": deleted_from_kb, + "message": message, + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py index f6e586a2e..0fd7b28b3 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py @@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector +from app.db import async_session_maker from app.services.jira import JiraToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,26 @@ def create_update_jira_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): + """Factory function to create the update_jira_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Jira connector + user_id: User ID for fetching user-specific context + connector_id: Optional specific connector ID (if known) + + Returns: + Configured update_jira_issue tool + """ + del db_session # per-call session — see docstring + @tool async def update_jira_issue( issue_title_or_key: str, @@ -48,169 +69,177 @@ def create_update_jira_issue_tool( f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Jira tool not properly configured."} try: - metadata_service = JiraToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, issue_title_or_key - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "jira", - } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - - issue_data = context["issue"] - issue_key = issue_data["issue_id"] - document_id = issue_data.get("document_id") - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="jira_issue_update", - tool_name="update_jira_issue", - params={ - "issue_key": issue_key, - "document_id": document_id, - "new_summary": new_summary, - "new_description": new_description, - "new_priority": new_priority, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_issue_key = result.params.get("issue_key", issue_key) - final_summary = result.params.get("new_summary", new_summary) - final_description = result.params.get("new_description", new_description) - final_priority = result.params.get("new_priority", new_priority) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_document_id = result.params.get("document_id", document_id) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this issue.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = JiraToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, issue_title_or_key ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Jira connector is invalid.", - } - fields: dict[str, Any] = {} - if final_summary: - fields["summary"] = final_summary - if final_description is not None: - fields["description"] = { - "type": "doc", - "version": 1, - "content": [ - { - "type": "paragraph", - "content": [{"type": "text", "text": final_description}], + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "jira", } - ], - } - if final_priority: - fields["priority"] = {"name": final_priority} + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} - if not fields: - return {"status": "error", "message": "No changes specified."} + issue_data = context["issue"] + issue_key = issue_data["issue_id"] + document_id = issue_data.get("document_id") + connector_id_from_context = context.get("account", {}).get("id") - try: - jira_history = JiraHistoryConnector( - session=db_session, connector_id=final_connector_id + result = request_approval( + action_type="jira_issue_update", + tool_name="update_jira_issue", + params={ + "issue_key": issue_key, + "document_id": document_id, + "new_summary": new_summary, + "new_description": new_description, + "new_priority": new_priority, + "connector_id": connector_id_from_context, + }, + context=context, ) - jira_client = await jira_history._get_jira_client() - await asyncio.to_thread( - jira_client.update_issue, final_issue_key, fields - ) - except Exception as api_err: - if "status code 403" in str(api_err).lower(): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", } - raise - issue_url = ( - f"{jira_history._base_url}/browse/{final_issue_key}" - if jira_history._base_url and final_issue_key - else "" - ) + final_issue_key = result.params.get("issue_key", issue_key) + final_summary = result.params.get("new_summary", new_summary) + final_description = result.params.get( + "new_description", new_description + ) + final_priority = result.params.get("new_priority", new_priority) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_document_id = result.params.get("document_id", document_id) - kb_message_suffix = "" - if final_document_id: - try: - from app.services.jira import JiraKBSyncService + from sqlalchemy.future import select - kb_service = JiraKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=final_document_id, - issue_id=final_issue_key, - user_id=user_id, - search_space_id=search_space_id, + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this issue.", + } + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Jira connector is invalid.", + } + + fields: dict[str, Any] = {} + if final_summary: + fields["summary"] = final_summary + if final_description is not None: + fields["description"] = { + "type": "doc", + "version": 1, + "content": [ + { + "type": "paragraph", + "content": [ + {"type": "text", "text": final_description} + ], + } + ], + } + if final_priority: + fields["priority"] = {"name": final_priority} + + if not fields: + return {"status": "error", "message": "No changes specified."} + + try: + jira_history = JiraHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + jira_client = await jira_history._get_jira_client() + await asyncio.to_thread( + jira_client.update_issue, final_issue_key, fields + ) + except Exception as api_err: + if "status code 403" in str(api_err).lower(): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + issue_url = ( + f"{jira_history._base_url}/browse/{final_issue_key}" + if jira_history._base_url and final_issue_key + else "" + ) + + kb_message_suffix = "" + if final_document_id: + try: + from app.services.jira import JiraKBSyncService + + kb_service = JiraKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=final_document_id, + issue_id=final_issue_key, + user_id=user_id, + search_space_id=search_space_id, ) - else: + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." 
+ ) + else: + kb_message_suffix = ( + " The knowledge base will be updated in the next sync." + ) + except Exception as kb_err: + logger.warning(f"KB sync after update failed: {kb_err}") kb_message_suffix = ( " The knowledge base will be updated in the next sync." ) - except Exception as kb_err: - logger.warning(f"KB sync after update failed: {kb_err}") - kb_message_suffix = ( - " The knowledge base will be updated in the next sync." - ) - return { - "status": "success", - "issue_key": final_issue_key, - "issue_url": issue_url, - "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}", - } + return { + "status": "success", + "issue_key": final_issue_key, + "issue_url": issue_url, + "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py index ff254e133..f897bee7a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector +from app.db import async_session_maker from app.services.linear import LinearToolMetadataService logger = logging.getLogger(__name__) @@ -17,11 +18,17 @@ def create_create_linear_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): - """ - Factory function to create the create_linear_issue tool. + """Factory function to create the create_linear_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Args: - db_session: Database session for accessing the Linear connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
search_space_id: Search space ID to find the Linear connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_create_linear_issue_tool( Returns: Configured create_linear_issue tool """ + del db_session # per-call session — see docstring @tool async def create_linear_issue( @@ -65,7 +73,7 @@ def create_create_linear_issue_tool( """ logger.info(f"create_linear_issue called: title='{title}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Linear tool not properly configured - missing required parameters" ) @@ -75,160 +83,170 @@ def create_create_linear_issue_tool( } try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - workspaces = context.get("workspaces", []) - if workspaces and all(w.get("auth_expired") for w in workspaces): - logger.warning("All Linear accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "linear", - } - - logger.info(f"Requesting approval for creating Linear issue: '{title}'") - result = request_approval( - action_type="linear_issue_creation", - tool_name="create_linear_issue", - params={ - "title": title, - "description": description, - "team_id": None, - "state_id": None, - "assignee_id": None, - "priority": None, - "label_ids": [], - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue creation rejected by user") - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_title = result.params.get("title", title) - final_description = result.params.get("description", description) - final_team_id = result.params.get("team_id") - final_state_id = result.params.get("state_id") - final_assignee_id = result.params.get("assignee_id") - final_priority = result.params.get("priority") - final_label_ids = result.params.get("label_ids") or [] - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_title or not final_title.strip(): - logger.error("Title is empty or contains only whitespace") - return {"status": "error", "message": "Issue title cannot be empty."} - if not final_team_id: - return { - "status": "error", - "message": "A team must be selected to create an issue.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = LinearToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: + + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" + ) + return {"status": "error", "message": context["error"]} + + workspaces = context.get("workspaces", []) + if workspaces and all(w.get("auth_expired") for w in workspaces): + logger.warning("All Linear accounts have expired authentication") + return { + "status": "auth_error", + "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "linear", + } + + logger.info(f"Requesting approval for creating Linear issue: '{title}'") + result = request_approval( + action_type="linear_issue_creation", + tool_name="create_linear_issue", + params={ + "title": title, + "description": description, + "team_id": None, + "state_id": None, + "assignee_id": None, + "priority": None, + "label_ids": [], + "connector_id": connector_id, + }, + context=context, + ) + + if result.rejected: + logger.info("Linear issue creation rejected by user") + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_title = result.params.get("title", title) + final_description = result.params.get("description", description) + final_team_id = result.params.get("team_id") + final_state_id = result.params.get("state_id") + final_assignee_id = result.params.get("assignee_id") + final_priority = result.params.get("priority") + final_label_ids = result.params.get("label_ids") or [] + final_connector_id = result.params.get("connector_id", connector_id) + + if not final_title or not final_title.strip(): + logger.error("Title is empty or contains only whitespace") return { "status": "error", - "message": "No Linear connector found. 
Please connect Linear in your workspace settings.", + "message": "Issue title cannot be empty.", } - actual_connector_id = connector.id - logger.info(f"Found Linear connector: id={actual_connector_id}") - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: + if not final_team_id: return { "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", + "message": "A team must be selected to create an issue.", } - logger.info(f"Validated Linear connector: id={actual_connector_id}") - logger.info( - f"Creating Linear issue with final params: title='{final_title}'" - ) - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) - result = await linear_client.create_issue( - team_id=final_team_id, - title=final_title, - description=final_description, - state_id=final_state_id, - assignee_id=final_assignee_id, - priority=final_priority, - label_ids=final_label_ids if final_label_ids else None, - ) + from sqlalchemy.future import select - if result.get("status") == "error": - logger.error(f"Failed to create Linear issue: {result.get('message')}") - return {"status": "error", "message": result.get("message")} + from app.db import SearchSourceConnector, SearchSourceConnectorType - logger.info( - f"Linear issue created: {result.get('identifier')} - {result.get('title')}" - ) - - kb_message_suffix = "" - try: - from app.services.linear import LinearKBSyncService - - kb_service = LinearKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - issue_id=result.get("id"), - issue_identifier=result.get("identifier", ""), - issue_title=result.get("title", final_title), - issue_url=result.get("url"), - description=final_description, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Linear connector found. Please connect Linear in your workspace settings.", + } + actual_connector_id = connector.id + logger.info(f"Found Linear connector: id={actual_connector_id}") else: - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." 
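+                    # The connector id arrives via the approval payload, so it is
+                    # re-validated against the fresh per-call session before any
+                    # Linear API call is made.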
+ result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Linear connector is invalid or has been disconnected.", + } + logger.info(f"Validated Linear connector: id={actual_connector_id}") - return { - "status": "success", - "issue_id": result.get("id"), - "identifier": result.get("identifier"), - "url": result.get("url"), - "message": (result.get("message", "") + kb_message_suffix), - } + logger.info( + f"Creating Linear issue with final params: title='{final_title}'" + ) + linear_client = LinearConnector( + session=db_session, connector_id=actual_connector_id + ) + result = await linear_client.create_issue( + team_id=final_team_id, + title=final_title, + description=final_description, + state_id=final_state_id, + assignee_id=final_assignee_id, + priority=final_priority, + label_ids=final_label_ids if final_label_ids else None, + ) + + if result.get("status") == "error": + logger.error( + f"Failed to create Linear issue: {result.get('message')}" + ) + return {"status": "error", "message": result.get("message")} + + logger.info( + f"Linear issue created: {result.get('identifier')} - {result.get('title')}" + ) + + kb_message_suffix = "" + try: + from app.services.linear import LinearKBSyncService + + kb_service = LinearKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + issue_id=result.get("id"), + issue_identifier=result.get("identifier", ""), + issue_title=result.get("title", final_title), + issue_url=result.get("url"), + description=final_description, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "issue_id": result.get("id"), + "identifier": result.get("identifier"), + "url": result.get("url"), + "message": (result.get("message", "") + kb_message_suffix), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py index 29ef0cdf2..c5039a8eb 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector +from app.db import async_session_maker from app.services.linear import LinearToolMetadataService logger = logging.getLogger(__name__) @@ -17,11 +18,17 @@ def create_delete_linear_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): - """ - Factory function to create the delete_linear_issue tool. 
+ """Factory function to create the delete_linear_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Args: - db_session: Database session for accessing the Linear connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. search_space_id: Search space ID to find the Linear connector user_id: User ID for finding the correct Linear connector connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_delete_linear_issue_tool( Returns: Configured delete_linear_issue tool """ + del db_session # per-call session — see docstring @tool async def delete_linear_issue( @@ -73,7 +81,7 @@ def create_delete_linear_issue_tool( f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Linear tool not properly configured - missing required parameters" ) @@ -83,149 +91,152 @@ def create_delete_linear_issue_tool( } try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_delete_context( - search_space_id, user_id, issue_ref - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - logger.warning(f"Auth expired for delete context: {error_msg}") - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "linear", - } - if "not found" in error_msg.lower(): - logger.warning(f"Issue not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - else: - logger.error(f"Failed to fetch delete context: {error_msg}") - return {"status": "error", "message": error_msg} - - issue_id = context["issue"]["id"] - issue_identifier = context["issue"].get("identifier", "") - document_id = context["issue"]["document_id"] - connector_id_from_context = context.get("workspace", {}).get("id") - - logger.info( - f"Requesting approval for deleting Linear issue: '{issue_ref}' " - f"(id={issue_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="linear_issue_deletion", - tool_name="delete_linear_issue", - params={ - "issue_id": issue_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue deletion rejected by user") - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_issue_id = result.params.get("issue_id", issue_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - logger.info( - f"Deleting Linear issue with final params: issue_id={final_issue_id}, " - f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = LinearToolMetadataService(db_session) + context = await metadata_service.get_delete_context( + search_space_id, user_id, issue_ref ) - connector = result.scalars().first() - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + logger.warning(f"Auth expired for delete context: {error_msg}") + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "linear", + } + if "not found" in error_msg.lower(): + logger.warning(f"Issue not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + else: + logger.error(f"Failed to fetch delete context: {error_msg}") + return {"status": "error", "message": error_msg} + + issue_id = context["issue"]["id"] + issue_identifier = context["issue"].get("identifier", "") + document_id = context["issue"]["document_id"] + connector_id_from_context = context.get("workspace", {}).get("id") + + logger.info( + f"Requesting approval for deleting Linear issue: '{issue_ref}' " + f"(id={issue_id}, delete_from_kb={delete_from_kb})" + ) + result = request_approval( + action_type="linear_issue_deletion", + tool_name="delete_linear_issue", + params={ + "issue_id": issue_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, + ) + + if result.rejected: + logger.info("Linear issue deletion rejected by user") + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_issue_id = result.params.get("issue_id", issue_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + logger.info( + f"Deleting Linear issue with final params: issue_id={final_issue_id}, " + f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if final_connector_id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) ) + connector = result.scalars().first() + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Linear connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id + logger.info(f"Validated Linear connector: id={actual_connector_id}") + else: + logger.error("No connector found for this issue") return { "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", + "message": "No connector found for this issue.", } - actual_connector_id = connector.id - logger.info(f"Validated Linear connector: id={actual_connector_id}") - else: - logger.error("No connector found for this issue") - return { - "status": "error", - "message": "No connector found for this issue.", - } - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) + linear_client = LinearConnector( + session=db_session, connector_id=actual_connector_id + ) - result = await linear_client.archive_issue(issue_id=final_issue_id) + result = await linear_client.archive_issue(issue_id=final_issue_id) - logger.info( - f"archive_issue result: {result.get('status')} - {result.get('message', '')}" - ) + logger.info( + f"archive_issue result: {result.get('status')} - {result.get('message', '')}" + ) - deleted_from_kb = False - if ( - result.get("status") == "success" - and final_delete_from_kb - and document_id - ): - try: - from app.db import Document + deleted_from_kb = False + if ( + result.get("status") == "success" + and final_delete_from_kb + and document_id + ): + try: + from app.db import Document - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + result["warning"] = ( + f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}" ) - else: - 
logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - result["warning"] = ( - f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}" - ) - if result.get("status") == "success": - result["deleted_from_kb"] = deleted_from_kb - if issue_identifier: - result["message"] = ( - f"Issue {issue_identifier} archived successfully." - ) - if deleted_from_kb: - result["message"] = ( - f"{result.get('message', '')} Also removed from the knowledge base." - ) + if result.get("status") == "success": + result["deleted_from_kb"] = deleted_from_kb + if issue_identifier: + result["message"] = ( + f"Issue {issue_identifier} archived successfully." + ) + if deleted_from_kb: + result["message"] = ( + f"{result.get('message', '')} Also removed from the knowledge base." + ) - return result + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py index f35d0dddd..d610ce2b7 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector +from app.db import async_session_maker from app.services.linear import LinearKBSyncService, LinearToolMetadataService logger = logging.getLogger(__name__) @@ -17,11 +18,17 @@ def create_update_linear_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): - """ - Factory function to create the update_linear_issue tool. + """Factory function to create the update_linear_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Args: - db_session: Database session for accessing the Linear connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
search_space_id: Search space ID to find the Linear connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_update_linear_issue_tool( Returns: Configured update_linear_issue tool """ + del db_session # per-call session — see docstring @tool async def update_linear_issue( @@ -86,7 +94,7 @@ def create_update_linear_issue_tool( """ logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Linear tool not properly configured - missing required parameters" ) @@ -96,176 +104,177 @@ def create_update_linear_issue_tool( } try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, issue_ref - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - logger.warning(f"Auth expired for update context: {error_msg}") - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "linear", - } - if "not found" in error_msg.lower(): - logger.warning(f"Issue not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - else: - logger.error(f"Failed to fetch update context: {error_msg}") - return {"status": "error", "message": error_msg} - - issue_id = context["issue"]["id"] - document_id = context["issue"]["document_id"] - connector_id_from_context = context.get("workspace", {}).get("id") - - team = context.get("team", {}) - new_state_id = _resolve_state(team, new_state_name) - new_assignee_id = _resolve_assignee(team, new_assignee_email) - new_label_ids = _resolve_labels(team, new_label_names) - - logger.info( - f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})" - ) - result = request_approval( - action_type="linear_issue_update", - tool_name="update_linear_issue", - params={ - "issue_id": issue_id, - "document_id": document_id, - "new_title": new_title, - "new_description": new_description, - "new_state_id": new_state_id, - "new_assignee_id": new_assignee_id, - "new_priority": new_priority, - "new_label_ids": new_label_ids, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue update rejected by user") - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_issue_id = result.params.get("issue_id", issue_id) - final_document_id = result.params.get("document_id", document_id) - final_new_title = result.params.get("new_title", new_title) - final_new_description = result.params.get( - "new_description", new_description - ) - final_new_state_id = result.params.get("new_state_id", new_state_id) - final_new_assignee_id = result.params.get( - "new_assignee_id", new_assignee_id - ) - final_new_priority = result.params.get("new_priority", new_priority) - final_new_label_ids: list[str] | None = result.params.get( - "new_label_ids", new_label_ids - ) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - - if not final_connector_id: - logger.error("No connector found for this issue") - return { - "status": "error", - "message": "No connector found for this issue.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = LinearToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, issue_ref ) - ) - connector = result.scalars().first() - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", - } - logger.info(f"Validated Linear connector: id={final_connector_id}") - logger.info( - f"Updating Linear issue with final params: issue_id={final_issue_id}" - ) - linear_client = LinearConnector( - session=db_session, connector_id=final_connector_id - ) - updated_issue = await linear_client.update_issue( - issue_id=final_issue_id, - title=final_new_title, - description=final_new_description, - state_id=final_new_state_id, - assignee_id=final_new_assignee_id, - priority=final_new_priority, - label_ids=final_new_label_ids, - ) + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + logger.warning(f"Auth expired for update context: {error_msg}") + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "linear", + } + if "not found" in error_msg.lower(): + logger.warning(f"Issue not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + else: + logger.error(f"Failed to fetch update context: {error_msg}") + return {"status": "error", "message": error_msg} - if updated_issue.get("status") == "error": - logger.error( - f"Failed to update Linear issue: {updated_issue.get('message')}" - ) - return { - "status": "error", - "message": updated_issue.get("message"), - } + issue_id = context["issue"]["id"] + document_id = context["issue"]["document_id"] + connector_id_from_context = context.get("workspace", {}).get("id") - logger.info( - f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}" - ) + team = context.get("team", {}) + new_state_id = _resolve_state(team, new_state_name) + new_assignee_id = _resolve_assignee(team, new_assignee_email) + 
new_label_ids = _resolve_labels(team, new_label_names) - if final_document_id is not None: logger.info( - f"Updating knowledge base for document {final_document_id}..." + f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})" ) - kb_service = LinearKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=final_document_id, - issue_id=final_issue_id, - user_id=user_id, - search_space_id=search_space_id, + result = request_approval( + action_type="linear_issue_update", + tool_name="update_linear_issue", + params={ + "issue_id": issue_id, + "document_id": document_id, + "new_title": new_title, + "new_description": new_description, + "new_state_id": new_state_id, + "new_assignee_id": new_assignee_id, + "new_priority": new_priority, + "new_label_ids": new_label_ids, + "connector_id": connector_id_from_context, + }, + context=context, ) - if kb_result["status"] == "success": - logger.info( - f"Knowledge base successfully updated for issue {final_issue_id}" - ) - kb_message = " Your knowledge base has also been updated." - elif kb_result["status"] == "not_indexed": - kb_message = " This issue will be added to your knowledge base in the next scheduled sync." - else: - logger.warning( - f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}" - ) - kb_message = " Your knowledge base will be updated in the next scheduled sync." - else: - kb_message = "" - identifier = updated_issue.get("identifier") - default_msg = f"Issue {identifier} updated successfully." - return { - "status": "success", - "identifier": identifier, - "url": updated_issue.get("url"), - "message": f"{updated_issue.get('message', default_msg)}{kb_message}", - } + if result.rejected: + logger.info("Linear issue update rejected by user") + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_issue_id = result.params.get("issue_id", issue_id) + final_document_id = result.params.get("document_id", document_id) + final_new_title = result.params.get("new_title", new_title) + final_new_description = result.params.get( + "new_description", new_description + ) + final_new_state_id = result.params.get("new_state_id", new_state_id) + final_new_assignee_id = result.params.get( + "new_assignee_id", new_assignee_id + ) + final_new_priority = result.params.get("new_priority", new_priority) + final_new_label_ids: list[str] | None = result.params.get( + "new_label_ids", new_label_ids + ) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + + if not final_connector_id: + logger.error("No connector found for this issue") + return { + "status": "error", + "message": "No connector found for this issue.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Linear connector is invalid or has been disconnected.", + } + logger.info(f"Validated Linear connector: id={final_connector_id}") + + logger.info( + f"Updating Linear issue with final params: issue_id={final_issue_id}" + ) + linear_client = LinearConnector( + session=db_session, connector_id=final_connector_id + ) + updated_issue = await linear_client.update_issue( + issue_id=final_issue_id, + title=final_new_title, + description=final_new_description, + state_id=final_new_state_id, + assignee_id=final_new_assignee_id, + priority=final_new_priority, + label_ids=final_new_label_ids, + ) + + if updated_issue.get("status") == "error": + logger.error( + f"Failed to update Linear issue: {updated_issue.get('message')}" + ) + return { + "status": "error", + "message": updated_issue.get("message"), + } + + logger.info( + f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}" + ) + + if final_document_id is not None: + logger.info( + f"Updating knowledge base for document {final_document_id}..." + ) + kb_service = LinearKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=final_document_id, + issue_id=final_issue_id, + user_id=user_id, + search_space_id=search_space_id, + ) + if kb_result["status"] == "success": + logger.info( + f"Knowledge base successfully updated for issue {final_issue_id}" + ) + kb_message = " Your knowledge base has also been updated." + elif kb_result["status"] == "not_indexed": + kb_message = " This issue will be added to your knowledge base in the next scheduled sync." + else: + logger.warning( + f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}" + ) + kb_message = " Your knowledge base will be updated in the next scheduled sync." + else: + kb_message = "" + + identifier = updated_issue.get("identifier") + default_msg = f"Issue {identifier} updated successfully." 
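+                # Prefer the message supplied by the Linear client; fall back to
+                # the generic success line when the response omits one.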
+ return { + "status": "success", + "identifier": identifier, + "url": updated_issue.get("url"), + "message": f"{updated_issue.get('message', default_msg)}{kb_message}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py index 0a24a988f..65c177d7a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py @@ -6,6 +6,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers @@ -17,6 +18,23 @@ def create_create_luma_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_luma_event tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_luma_event tool + """ + del db_session # per-call session — see docstring + @tool async def create_luma_event( name: str, @@ -40,83 +58,86 @@ def create_create_luma_event_tool( IMPORTANT: - If status is "rejected", the user explicitly declined. Do NOT retry. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Luma tool not properly configured."} try: - connector = await get_luma_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Luma connector found."} + async with async_session_maker() as db_session: + connector = await get_luma_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Luma connector found."} - result = request_approval( - action_type="luma_create_event", - tool_name="create_luma_event", - params={ - "name": name, - "start_at": start_at, - "end_at": end_at, - "description": description, - "timezone": timezone, - }, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Event was not created.", - } - - final_name = result.params.get("name", name) - final_start = result.params.get("start_at", start_at) - final_end = result.params.get("end_at", end_at) - final_desc = result.params.get("description", description) - final_tz = result.params.get("timezone", timezone) - - api_key = get_api_key(connector) - headers = luma_headers(api_key) - - body: dict[str, Any] = { - "name": final_name, - "start_at": final_start, - "end_at": final_end, - "timezone": final_tz, - } - if final_desc: - body["description_md"] = final_desc - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.post( - f"{LUMA_API}/event/create", - headers=headers, - json=body, + result = request_approval( + action_type="luma_create_event", + tool_name="create_luma_event", + params={ + "name": name, + "start_at": start_at, + "end_at": end_at, + "description": description, + "timezone": timezone, + }, + context={"connector_id": connector.id}, ) - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Luma Plus subscription required to create events via API.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Event was not created.", + } - data = resp.json() - event_id = data.get("api_id") or data.get("event", {}).get("api_id") + final_name = result.params.get("name", name) + final_start = result.params.get("start_at", start_at) + final_end = result.params.get("end_at", end_at) + final_desc = result.params.get("description", description) + final_tz = result.params.get("timezone", timezone) - return { - "status": "success", - "event_id": event_id, - "message": f"Event '{final_name}' created on Luma.", - } + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + body: dict[str, Any] = { + "name": final_name, + "start_at": final_start, + "end_at": final_end, + "timezone": final_tz, + } + if final_desc: + body["description_md"] = final_desc + + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.post( + f"{LUMA_API}/event/create", + headers=headers, + json=body, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Luma Plus subscription required to create events via API.", + } + if resp.status_code not in (200, 201): + return { + "status": "error", + "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}", + } + + data = resp.json() + event_id = data.get("api_id") or data.get("event", {}).get("api_id") + + return { + "status": "success", + "event_id": event_id, + "message": f"Event '{final_name}' created on Luma.", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py index aec5ad220..6885c2049 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db 
import async_session_maker + from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_list_luma_events_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the list_luma_events tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured list_luma_events tool + """ + del db_session # per-call session — see docstring + @tool async def list_luma_events( max_results: int = 25, @@ -28,77 +47,80 @@ def create_list_luma_events_tool( Dictionary with status and a list of events including event_id, name, start_at, end_at, location, url. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Luma tool not properly configured."} max_results = min(max_results, 50) try: - connector = await get_luma_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Luma connector found."} + async with async_session_maker() as db_session: + connector = await get_luma_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Luma connector found."} - api_key = get_api_key(connector) - headers = luma_headers(api_key) + api_key = get_api_key(connector) + headers = luma_headers(api_key) - all_entries: list[dict] = [] - cursor = None + all_entries: list[dict] = [] + cursor = None - async with httpx.AsyncClient(timeout=20.0) as client: - while len(all_entries) < max_results: - params: dict[str, Any] = { - "limit": min(100, max_results - len(all_entries)) - } - if cursor: - params["cursor"] = cursor + async with httpx.AsyncClient(timeout=20.0) as client: + while len(all_entries) < max_results: + params: dict[str, Any] = { + "limit": min(100, max_results - len(all_entries)) + } + if cursor: + params["cursor"] = cursor - resp = await client.get( - f"{LUMA_API}/calendar/list-events", - headers=headers, - params=params, + resp = await client.get( + f"{LUMA_API}/calendar/list-events", + headers=headers, + params=params, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Luma API error: {resp.status_code}", + } + + data = resp.json() + entries = data.get("entries", []) + if not entries: + break + all_entries.extend(entries) + + next_cursor = data.get("next_cursor") + if not next_cursor: + break + cursor = next_cursor + + events = [] + for entry in all_entries[:max_results]: + ev = entry.get("event", {}) + geo = ev.get("geo_info", {}) + events.append( + { + "event_id": entry.get("api_id"), + "name": ev.get("name", "Untitled"), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location": geo.get("name", ""), + "url": ev.get("url", ""), + "visibility": ev.get("visibility", ""), + } ) - if resp.status_code == 401: - return { - "status": "auth_error", - 
"message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Luma API error: {resp.status_code}", - } - - data = resp.json() - entries = data.get("entries", []) - if not entries: - break - all_entries.extend(entries) - - next_cursor = data.get("next_cursor") - if not next_cursor: - break - cursor = next_cursor - - events = [] - for entry in all_entries[:max_results]: - ev = entry.get("event", {}) - geo = ev.get("geo_info", {}) - events.append( - { - "event_id": entry.get("api_id"), - "name": ev.get("name", "Untitled"), - "start_at": ev.get("start_at", ""), - "end_at": ev.get("end_at", ""), - "timezone": ev.get("timezone", ""), - "location": geo.get("name", ""), - "url": ev.get("url", ""), - "visibility": ev.get("visibility", ""), - } - ) - - return {"status": "success", "events": events, "total": len(events)} + return {"status": "success", "events": events, "total": len(events)} except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py index b37a9d617..a8484e9c0 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_read_luma_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_luma_event tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_luma_event tool + """ + del db_session # per-call session — see docstring + @tool async def read_luma_event(event_id: str) -> dict[str, Any]: """Read detailed information about a specific Luma event. @@ -26,60 +45,63 @@ def create_read_luma_event_tool( Dictionary with status and full event details including description, attendees count, meeting URL. 
""" - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Luma tool not properly configured."} try: - connector = await get_luma_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Luma connector found."} - - api_key = get_api_key(connector) - headers = luma_headers(api_key) - - async with httpx.AsyncClient(timeout=15.0) as client: - resp = await client.get( - f"{LUMA_API}/events/{event_id}", - headers=headers, + async with async_session_maker() as db_session: + connector = await get_luma_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Luma connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code == 404: - return { - "status": "not_found", - "message": f"Event '{event_id}' not found.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Luma API error: {resp.status_code}", + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + async with httpx.AsyncClient(timeout=15.0) as client: + resp = await client.get( + f"{LUMA_API}/events/{event_id}", + headers=headers, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } + if resp.status_code == 404: + return { + "status": "not_found", + "message": f"Event '{event_id}' not found.", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Luma API error: {resp.status_code}", + } + + data = resp.json() + ev = data.get("event", data) + geo = ev.get("geo_info", {}) + + event_detail = { + "event_id": event_id, + "name": ev.get("name", ""), + "description": ev.get("description", ""), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location_name": geo.get("name", ""), + "address": geo.get("address", ""), + "url": ev.get("url", ""), + "meeting_url": ev.get("meeting_url", ""), + "visibility": ev.get("visibility", ""), + "cover_url": ev.get("cover_url", ""), } - data = resp.json() - ev = data.get("event", data) - geo = ev.get("geo_info", {}) - - event_detail = { - "event_id": event_id, - "name": ev.get("name", ""), - "description": ev.get("description", ""), - "start_at": ev.get("start_at", ""), - "end_at": ev.get("end_at", ""), - "timezone": ev.get("timezone", ""), - "location_name": geo.get("name", ""), - "address": geo.get("address", ""), - "url": ev.get("url", ""), - "meeting_url": ev.get("meeting_url", ""), - "visibility": ev.get("visibility", ""), - "cover_url": ev.get("cover_url", ""), - } - - return {"status": "success", "event": event_detail} + return {"status": "success", "event": event_detail} except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py index 6efffe960..6ec95e9f0 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, 
NotionHistoryConnector +from app.db import async_session_maker from app.services.notion import NotionToolMetadataService logger = logging.getLogger(__name__) @@ -20,8 +21,17 @@ def create_create_notion_page_tool( """ Factory function to create the create_notion_page tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Per-call sessions also + keep the request's outer transaction free of long-running Notion API + blocking. + Args: - db_session: Database session for accessing Notion connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. search_space_id: Search space ID to find the Notion connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +39,7 @@ def create_create_notion_page_tool( Returns: Configured create_notion_page tool """ + del db_session # per-call session — see docstring @tool async def create_notion_page( @@ -67,7 +78,7 @@ def create_create_notion_page_tool( """ logger.info(f"create_notion_page called: title='{title}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Notion tool not properly configured - missing required parameters" ) @@ -77,154 +88,157 @@ def create_create_notion_page_tool( } try: - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return { - "status": "error", - "message": context["error"], - } - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Notion accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "notion", - } - - logger.info(f"Requesting approval for creating Notion page: '{title}'") - result = request_approval( - action_type="notion_page_creation", - tool_name="create_notion_page", - params={ - "title": title, - "content": content, - "parent_page_id": None, - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page creation rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_title = result.params.get("title", title) - final_content = result.params.get("content", content) - final_parent_page_id = result.params.get("parent_page_id") - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_title or not final_title.strip(): - logger.error("Title is empty or contains only whitespace") - return { - "status": "error", - "message": "Page title cannot be empty. 
Please provide a valid title.", - } - - logger.info( - f"Creating Notion page with final params: title='{final_title}'" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = NotionToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: - logger.warning( - f"No Notion connector found for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "No Notion connector found. Please connect Notion in your workspace settings.", - } - - actual_connector_id = connector.id - logger.info(f"Found Notion connector: id={actual_connector_id}") - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: + if "error" in context: logger.error( - f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}" + f"Failed to fetch creation context: {context['error']}" ) return { "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", + "message": context["error"], } - logger.info(f"Validated Notion connector: id={actual_connector_id}") - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + logger.warning("All Notion accounts have expired authentication") + return { + "status": "auth_error", + "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "notion", + } - result = await notion_connector.create_page( - title=final_title, - content=final_content, - parent_page_id=final_parent_page_id, - ) - logger.info( - f"create_page result: {result.get('status')} - {result.get('message', '')}" - ) + logger.info(f"Requesting approval for creating Notion page: '{title}'") + result = request_approval( + action_type="notion_page_creation", + tool_name="create_notion_page", + params={ + "title": title, + "content": content, + "parent_page_id": None, + "connector_id": connector_id, + }, + context=context, + ) - if result.get("status") == "success": - kb_message_suffix = "" - try: - from app.services.notion import NotionKBSyncService + if result.rejected: + logger.info("Notion page creation rejected by user") + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } - kb_service = NotionKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - page_id=result.get("page_id"), - page_title=result.get("title", final_title), - page_url=result.get("url"), - content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." + final_title = result.params.get("title", title) + final_content = result.params.get("content", content) + final_parent_page_id = result.params.get("parent_page_id") + final_connector_id = result.params.get("connector_id", connector_id) + + if not final_title or not final_title.strip(): + logger.error("Title is empty or contains only whitespace") + return { + "status": "error", + "message": "Page title cannot be empty. Please provide a valid title.", + } + + logger.info( + f"Creating Notion page with final params: title='{final_title}'" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, ) - else: + ) + connector = result.scalars().first() + + if not connector: + logger.warning( + f"No Notion connector found for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "No Notion connector found. Please connect Notion in your workspace settings.", + } + + actual_connector_id = connector.id + logger.info(f"Found Notion connector: id={actual_connector_id}") + else: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + logger.error( + f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", + } + logger.info(f"Validated Notion connector: id={actual_connector_id}") + + notion_connector = NotionHistoryConnector( + session=db_session, + connector_id=actual_connector_id, + ) + + result = await notion_connector.create_page( + title=final_title, + content=final_content, + parent_page_id=final_parent_page_id, + ) + logger.info( + f"create_page result: {result.get('status')} - {result.get('message', '')}" + ) + + if result.get("status") == "success": + kb_message_suffix = "" + try: + from app.services.notion import NotionKBSyncService + + kb_service = NotionKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + page_id=result.get("page_id"), + page_title=result.get("title", final_title), + page_url=result.get("url"), + content=final_content, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." 
+ ) + else: + kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - result["message"] = result.get("message", "") + kb_message_suffix + result["message"] = result.get("message", "") + kb_message_suffix - return result + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py index 07f7583d2..7b85da4c2 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector +from app.db import async_session_maker from app.services.notion.tool_metadata_service import NotionToolMetadataService logger = logging.getLogger(__name__) @@ -20,8 +21,14 @@ def create_delete_notion_page_tool( """ Factory function to create the delete_notion_page tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + Args: - db_session: Database session for accessing Notion connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
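All of the write tools in this diff gate their side effects behind the same approval handshake. A condensed sketch of it, with the result shape (`.rejected`, `.params`) inferred from the call sites here; `request_approval`'s exact contract is not shown in this diff:

```python
# Sketch of the HITL approval gate. The result shape (.rejected, .params)
# is inferred from the call sites in this diff, not from hitl's source.
from typing import Any

from app.agents.new_chat.tools.hitl import request_approval


async def guarded_write(title: str, content: str) -> dict[str, Any]:
    result = request_approval(
        action_type="example_write",  # selects the approval card in the UI
        tool_name="guarded_write",
        params={"title": title, "content": content},
        context={},                   # extra metadata for the approval card
    )
    if result.rejected:
        # Hard stop; the message also tells the model not to retry.
        return {"status": "rejected", "message": "User declined."}

    # The user may edit any field in the approval UI, so each value is
    # re-read from result.params with the original argument as fallback.
    final_title = result.params.get("title", title)
    final_content = result.params.get("content", content)
    if not final_title or not final_title.strip():
        return {"status": "error", "message": "Title cannot be empty."}
    return {"status": "success", "title": final_title, "chars": len(final_content)}
```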
search_space_id: Search space ID to find the Notion connector user_id: User ID for finding the correct Notion connector connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_delete_notion_page_tool( Returns: Configured delete_notion_page tool """ + del db_session # per-call session — see docstring @tool async def delete_notion_page( @@ -63,7 +71,7 @@ def create_delete_notion_page_tool( f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Notion tool not properly configured - missing required parameters" ) @@ -73,164 +81,167 @@ def create_delete_notion_page_tool( } try: - # Get page context (page_id, account, title) from indexed data - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_delete_context( - search_space_id, user_id, page_title - ) - - if "error" in context: - error_msg = context["error"] - # Check if it's a "not found" error (softer handling for LLM) - if "not found" in error_msg.lower(): - logger.warning(f"Page not found: {error_msg}") - return { - "status": "not_found", - "message": error_msg, - } - else: - logger.error(f"Failed to fetch delete context: {error_msg}") - return { - "status": "error", - "message": error_msg, - } - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Notion account %s has expired authentication", - account.get("id"), + async with async_session_maker() as db_session: + # Get page context (page_id, account, title) from indexed data + metadata_service = NotionToolMetadataService(db_session) + context = await metadata_service.get_delete_context( + search_space_id, user_id, page_title ) - return { - "status": "auth_error", - "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", - } - page_id = context.get("page_id") - connector_id_from_context = account.get("id") - document_id = context.get("document_id") - - logger.info( - f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})" - ) - - result = request_approval( - action_type="notion_page_deletion", - tool_name="delete_notion_page", - params={ - "page_id": page_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page deletion rejected by user") - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - logger.info( - f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - # Validate the connector - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", - } - actual_connector_id = connector.id - logger.info(f"Validated Notion connector: id={actual_connector_id}") - else: - logger.error("No connector found for this page") - return { - "status": "error", - "message": "No connector found for this page.", - } - - # Create connector instance - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) - - # Delete the page from Notion - result = await notion_connector.delete_page(page_id=final_page_id) - logger.info( - f"delete_page result: {result.get('status')} - {result.get('message', '')}" - ) - - # If deletion was successful and user wants to delete from KB - deleted_from_kb = False - if ( - result.get("status") == "success" - and final_delete_from_kb - and document_id - ): - try: - from sqlalchemy.future import select - - from app.db import Document - - # Get the document - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) + if "error" in context: + error_msg = context["error"] + # Check if it's a "not found" error (softer handling for LLM) + if "not found" in error_msg.lower(): + logger.warning(f"Page not found: {error_msg}") + return { + "status": "not_found", + "message": error_msg, + } else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - result["warning"] = ( - f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}" - ) + logger.error(f"Failed to fetch delete context: {error_msg}") + return { + "status": "error", + "message": error_msg, + } - # Update result with KB deletion status - if result.get("status") == "success": - result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - result["message"] = ( - f"{result.get('message', '')} (also removed from knowledge base)" + account = context.get("account", {}) + if account.get("auth_expired"): + logger.warning( + "Notion account %s has expired authentication", + account.get("id"), ) + return { + "status": 
"auth_error", + "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", + } - return result + page_id = context.get("page_id") + connector_id_from_context = account.get("id") + document_id = context.get("document_id") + + logger.info( + f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})" + ) + + result = request_approval( + action_type="notion_page_deletion", + tool_name="delete_notion_page", + params={ + "page_id": page_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, + ) + + if result.rejected: + logger.info("Notion page deletion rejected by user") + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_page_id = result.params.get("page_id", page_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + logger.info( + f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + # Validate the connector + if final_connector_id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Notion account is invalid or has been disconnected. 
Please select a valid account.", + } + actual_connector_id = connector.id + logger.info(f"Validated Notion connector: id={actual_connector_id}") + else: + logger.error("No connector found for this page") + return { + "status": "error", + "message": "No connector found for this page.", + } + + # Create connector instance + notion_connector = NotionHistoryConnector( + session=db_session, + connector_id=actual_connector_id, + ) + + # Delete the page from Notion + result = await notion_connector.delete_page(page_id=final_page_id) + logger.info( + f"delete_page result: {result.get('status')} - {result.get('message', '')}" + ) + + # If deletion was successful and user wants to delete from KB + deleted_from_kb = False + if ( + result.get("status") == "success" + and final_delete_from_kb + and document_id + ): + try: + from sqlalchemy.future import select + + from app.db import Document + + # Get the document + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + result["warning"] = ( + f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}" + ) + + # Update result with KB deletion status + if result.get("status") == "success": + result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + result["message"] = ( + f"{result.get('message', '')} (also removed from knowledge base)" + ) + + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py index 85c08177c..df757476a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector +from app.db import async_session_maker from app.services.notion import NotionToolMetadataService logger = logging.getLogger(__name__) @@ -20,8 +21,14 @@ def create_update_notion_page_tool( """ Factory function to create the update_notion_page tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache (see + ``create_create_notion_page_tool`` for the full rationale). + Args: - db_session: Database session for accessing Notion connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
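Because the approval UI can swap the target account, each tool re-validates the connector id it gets back before touching the external API. The recurring query, factored into a helper purely for illustration (the models are the ones this diff already imports; a real helper would presumably take the connector type as a parameter):

```python
# Sketch of the post-approval connector re-validation used by the Notion
# and OneDrive tools in this diff: the id returned from the approval step
# must still belong to this user, search space, and connector type.
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import SearchSourceConnector, SearchSourceConnectorType


async def validate_connector(
    db_session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
) -> SearchSourceConnector | None:
    """Return the connector only if every ownership/type filter matches."""
    result = await db_session.execute(
        select(SearchSourceConnector).filter(
            SearchSourceConnector.id == connector_id,
            SearchSourceConnector.search_space_id == search_space_id,
            SearchSourceConnector.user_id == user_id,
            SearchSourceConnector.connector_type
            == SearchSourceConnectorType.NOTION_CONNECTOR,
        )
    )
    return result.scalars().first()  # None -> treat as invalid/disconnected
```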
search_space_id: Search space ID to find the Notion connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_update_notion_page_tool( Returns: Configured update_notion_page tool """ + del db_session # per-call session — see docstring @tool async def update_notion_page( @@ -71,7 +79,7 @@ def create_update_notion_page_tool( f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Notion tool not properly configured - missing required parameters" ) @@ -88,152 +96,155 @@ def create_update_notion_page_tool( } try: - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, page_title - ) - - if "error" in context: - error_msg = context["error"] - # Check if it's a "not found" error (softer handling for LLM) - if "not found" in error_msg.lower(): - logger.warning(f"Page not found: {error_msg}") - return { - "status": "not_found", - "message": error_msg, - } - else: - logger.error(f"Failed to fetch update context: {error_msg}") - return { - "status": "error", - "message": error_msg, - } - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Notion account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", - } - - page_id = context.get("page_id") - document_id = context.get("document_id") - connector_id_from_context = context.get("account", {}).get("id") - - logger.info( - f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})" - ) - result = request_approval( - action_type="notion_page_update", - tool_name="update_notion_page", - params={ - "page_id": page_id, - "content": content, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page update rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_content = result.params.get("content", content) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - - logger.info( - f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. 
Please select a valid account.", - } - actual_connector_id = connector.id - logger.info(f"Validated Notion connector: id={actual_connector_id}") - else: - logger.error("No connector found for this page") - return { - "status": "error", - "message": "No connector found for this page.", - } - - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) - - result = await notion_connector.update_page( - page_id=final_page_id, - content=final_content, - ) - logger.info( - f"update_page result: {result.get('status')} - {result.get('message', '')}" - ) - - if result.get("status") == "success" and document_id is not None: - from app.services.notion import NotionKBSyncService - - logger.info(f"Updating knowledge base for document {document_id}...") - kb_service = NotionKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=document_id, - appended_content=final_content, - user_id=user_id, - search_space_id=search_space_id, - appended_block_ids=result.get("appended_block_ids"), + async with async_session_maker() as db_session: + metadata_service = NotionToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, page_title ) - if kb_result["status"] == "success": - result["message"] = ( - f"{result['message']}. Your knowledge base has also been updated." - ) - logger.info( - f"Knowledge base successfully updated for page {final_page_id}" - ) - elif kb_result["status"] == "not_indexed": - result["message"] = ( - f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync." - ) - else: - result["message"] = ( - f"{result['message']}. Your knowledge base will be updated in the next scheduled sync." - ) + if "error" in context: + error_msg = context["error"] + # Check if it's a "not found" error (softer handling for LLM) + if "not found" in error_msg.lower(): + logger.warning(f"Page not found: {error_msg}") + return { + "status": "not_found", + "message": error_msg, + } + else: + logger.error(f"Failed to fetch update context: {error_msg}") + return { + "status": "error", + "message": error_msg, + } + + account = context.get("account", {}) + if account.get("auth_expired"): logger.warning( - f"KB update failed for page {final_page_id}: {kb_result['message']}" + "Notion account %s has expired authentication", + account.get("id"), + ) + return { + "status": "auth_error", + "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", + } + + page_id = context.get("page_id") + document_id = context.get("document_id") + connector_id_from_context = context.get("account", {}).get("id") + + logger.info( + f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})" + ) + result = request_approval( + action_type="notion_page_update", + tool_name="update_notion_page", + params={ + "page_id": page_id, + "content": content, + "connector_id": connector_id_from_context, + }, + context=context, + ) + + if result.rejected: + logger.info("Notion page update rejected by user") + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_page_id = result.params.get("page_id", page_id) + final_content = result.params.get("content", content) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + + logger.info( + f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if final_connector_id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", + } + actual_connector_id = connector.id + logger.info(f"Validated Notion connector: id={actual_connector_id}") + else: + logger.error("No connector found for this page") + return { + "status": "error", + "message": "No connector found for this page.", + } + + notion_connector = NotionHistoryConnector( + session=db_session, + connector_id=actual_connector_id, + ) + + result = await notion_connector.update_page( + page_id=final_page_id, + content=final_content, + ) + logger.info( + f"update_page result: {result.get('status')} - {result.get('message', '')}" + ) + + if result.get("status") == "success" and document_id is not None: + from app.services.notion import NotionKBSyncService + + logger.info( + f"Updating knowledge base for document {document_id}..." + ) + kb_service = NotionKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=document_id, + appended_content=final_content, + user_id=user_id, + search_space_id=search_space_id, + appended_block_ids=result.get("appended_block_ids"), ) - return result + if kb_result["status"] == "success": + result["message"] = ( + f"{result['message']}. Your knowledge base has also been updated." + ) + logger.info( + f"Knowledge base successfully updated for page {final_page_id}" + ) + elif kb_result["status"] == "not_indexed": + result["message"] = ( + f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync." + ) + else: + result["message"] = ( + f"{result['message']}. Your knowledge base will be updated in the next scheduled sync." 
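Every tool in this diff ends in the same `except Exception` block whose body falls outside the diff's context lines; the local `from langgraph.errors import GraphInterrupt` import strongly suggests the handler re-raises interrupts so the approval pause can propagate to the runtime. A hedged sketch of that presumed shape:

```python
# Presumed shape of the shared exception handler. Only the GraphInterrupt
# import is visible in this diff; the re-raise and fallback return are an
# assumption based on that import and on how request_approval pauses the
# graph (LangGraph interrupts surface as GraphInterrupt exceptions).
import logging

logger = logging.getLogger(__name__)


async def tool_body() -> dict:
    try:
        ...  # connector lookup, approval gate, external API call
        return {"status": "success"}
    except Exception as e:
        from langgraph.errors import GraphInterrupt

        if isinstance(e, GraphInterrupt):
            raise  # let the HITL interrupt bubble up to the runtime
        logger.error(f"Tool failed: {e}", exc_info=True)
        return {"status": "error", "message": f"Unexpected error: {e!s}"}
```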
+ ) + logger.warning( + f"KB update failed for page {final_page_id}: {kb_result['message']}" + ) + + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py index 21272e01d..5f199a41b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py @@ -10,7 +10,7 @@ from sqlalchemy.future import select from app.agents.new_chat.tools.hitl import request_approval from app.connectors.onedrive.client import OneDriveClient -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -48,6 +48,23 @@ def create_create_onedrive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_onedrive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_onedrive_file tool + """ + del db_session # per-call session — see docstring + @tool async def create_onedrive_file( name: str, @@ -70,173 +87,178 @@ def create_create_onedrive_file_tool( """ logger.info(f"create_onedrive_file called: name='{name}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "OneDrive tool not properly configured.", } try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, - ) - ) - connectors = result.scalars().all() - - if not connectors: - return { - "status": "error", - "message": "No OneDrive connector found. 
Please connect OneDrive in your workspace settings.", - } - - accounts = [] - for c in connectors: - cfg = c.config or {} - accounts.append( - { - "id": c.id, - "name": c.name, - "user_email": cfg.get("user_email"), - "auth_expired": cfg.get("auth_expired", False), - } - ) - - if all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected OneDrive accounts need re-authentication.", - "connector_type": "onedrive", - } - - parent_folders: dict[int, list[dict[str, str]]] = {} - for acc in accounts: - cid = acc["id"] - if acc.get("auth_expired"): - parent_folders[cid] = [] - continue - try: - client = OneDriveClient(session=db_session, connector_id=cid) - items, err = await client.list_children("root") - if err: - logger.warning( - "Failed to list folders for connector %s: %s", cid, err - ) - parent_folders[cid] = [] - else: - parent_folders[cid] = [ - {"folder_id": item["id"], "name": item["name"]} - for item in items - if item.get("folder") is not None - and item.get("id") - and item.get("name") - ] - except Exception: - logger.warning( - "Error fetching folders for connector %s", cid, exc_info=True - ) - parent_folders[cid] = [] - - context: dict[str, Any] = { - "accounts": accounts, - "parent_folders": parent_folders, - } - - result = request_approval( - action_type="onedrive_file_creation", - tool_name="create_onedrive_file", - params={ - "name": name, - "content": content, - "connector_id": None, - "parent_folder_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_id = result.params.get("parent_folder_id") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - final_name = _ensure_docx_extension(final_name) - - if final_connector_id is not None: + async with async_session_maker() as db_session: result = await db_session.execute( select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, ) ) - connector = result.scalars().first() - else: - connector = connectors[0] + connectors = result.scalars().all() - if not connector: - return { - "status": "error", - "message": "Selected OneDrive connector is invalid.", + if not connectors: + return { + "status": "error", + "message": "No OneDrive connector found. 
Please connect OneDrive in your workspace settings.", + } + + accounts = [] + for c in connectors: + cfg = c.config or {} + accounts.append( + { + "id": c.id, + "name": c.name, + "user_email": cfg.get("user_email"), + "auth_expired": cfg.get("auth_expired", False), + } + ) + + if all(a.get("auth_expired") for a in accounts): + return { + "status": "auth_error", + "message": "All connected OneDrive accounts need re-authentication.", + "connector_type": "onedrive", + } + + parent_folders: dict[int, list[dict[str, str]]] = {} + for acc in accounts: + cid = acc["id"] + if acc.get("auth_expired"): + parent_folders[cid] = [] + continue + try: + client = OneDriveClient(session=db_session, connector_id=cid) + items, err = await client.list_children("root") + if err: + logger.warning( + "Failed to list folders for connector %s: %s", cid, err + ) + parent_folders[cid] = [] + else: + parent_folders[cid] = [ + {"folder_id": item["id"], "name": item["name"]} + for item in items + if item.get("folder") is not None + and item.get("id") + and item.get("name") + ] + except Exception: + logger.warning( + "Error fetching folders for connector %s", + cid, + exc_info=True, + ) + parent_folders[cid] = [] + + context: dict[str, Any] = { + "accounts": accounts, + "parent_folders": parent_folders, } - docx_bytes = _markdown_to_docx(final_content or "") - - client = OneDriveClient(session=db_session, connector_id=connector.id) - created = await client.create_file( - name=final_name, - parent_id=final_parent_folder_id, - content=docx_bytes, - mime_type=DOCX_MIME, - ) - - logger.info( - f"OneDrive file created: id={created.get('id')}, name={created.get('name')}" - ) - - kb_message_suffix = "" - try: - from app.services.onedrive import OneDriveKBSyncService - - kb_service = OneDriveKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=created.get("id"), - file_name=created.get("name", final_name), - mime_type=DOCX_MIME, - web_url=created.get("webUrl"), - content=final_content, - connector_id=connector.id, - search_space_id=search_space_id, - user_id=user_id, + result = request_approval( + action_type="onedrive_file_creation", + tool_name="create_onedrive_file", + params={ + "name": name, + "content": content, + "connector_id": None, + "parent_folder_id": None, + }, + context=context, ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - return { - "status": "success", - "file_id": created.get("id"), - "name": created.get("name"), - "web_url": created.get("webUrl"), - "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_name = result.params.get("name", name) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_id = result.params.get("parent_folder_id") + + if not final_name or not final_name.strip(): + return {"status": "error", "message": "File name cannot be empty."} + + final_name = _ensure_docx_extension(final_name) + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + ) + ) + connector = result.scalars().first() + else: + connector = connectors[0] + + if not connector: + return { + "status": "error", + "message": "Selected OneDrive connector is invalid.", + } + + docx_bytes = _markdown_to_docx(final_content or "") + + client = OneDriveClient(session=db_session, connector_id=connector.id) + created = await client.create_file( + name=final_name, + parent_id=final_parent_folder_id, + content=docx_bytes, + mime_type=DOCX_MIME, + ) + + logger.info( + f"OneDrive file created: id={created.get('id')}, name={created.get('name')}" + ) + + kb_message_suffix = "" + try: + from app.services.onedrive import OneDriveKBSyncService + + kb_service = OneDriveKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + file_id=created.get("id"), + file_name=created.get("name", final_name), + mime_type=DOCX_MIME, + web_url=created.get("webUrl"), + content=final_content, + connector_id=connector.id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "file_id": created.get("id"), + "name": created.get("name"), + "web_url": created.get("webUrl"), + "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py index a7f13b5df..4857ea988 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py @@ -13,6 +13,7 @@ from app.db import ( DocumentType, SearchSourceConnector, SearchSourceConnectorType, + async_session_maker, ) logger = logging.getLogger(__name__) @@ -23,6 +24,23 @@ def create_delete_onedrive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the delete_onedrive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. 
Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured delete_onedrive_file tool + """ + del db_session # per-call session — see docstring + @tool async def delete_onedrive_file( file_name: str, @@ -56,33 +74,14 @@ def create_delete_onedrive_file_tool( f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "OneDrive tool not properly configured.", } try: - doc_result = await db_session.execute( - select(Document) - .join( - SearchSourceConnector, - Document.connector_id == SearchSourceConnector.id, - ) - .filter( - and_( - Document.search_space_id == search_space_id, - Document.document_type == DocumentType.ONEDRIVE_FILE, - func.lower(Document.title) == func.lower(file_name), - SearchSourceConnector.user_id == user_id, - ) - ) - .order_by(Document.updated_at.desc().nullslast()) - .limit(1) - ) - document = doc_result.scalars().first() - - if not document: + async with async_session_maker() as db_session: doc_result = await db_session.execute( select(Document) .join( @@ -93,13 +92,7 @@ def create_delete_onedrive_file_tool( and_( Document.search_space_id == search_space_id, Document.document_type == DocumentType.ONEDRIVE_FILE, - func.lower( - cast( - Document.document_metadata["onedrive_file_name"], - String, - ) - ) - == func.lower(file_name), + func.lower(Document.title) == func.lower(file_name), SearchSourceConnector.user_id == user_id, ) ) @@ -108,98 +101,64 @@ def create_delete_onedrive_file_tool( ) document = doc_result.scalars().first() - if not document: - return { - "status": "not_found", - "message": ( - f"File '{file_name}' not found in your indexed OneDrive files. " - "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " - "or (3) the file name is different." - ), - } - - if not document.connector_id: - return { - "status": "error", - "message": "Document has no associated connector.", - } - - meta = document.document_metadata or {} - file_id = meta.get("onedrive_file_id") - document_id = document.id - - if not file_id: - return { - "status": "error", - "message": "File ID is missing. 
Please re-index the file.", - } - - conn_result = await db_session.execute( - select(SearchSourceConnector).filter( - and_( - SearchSourceConnector.id == document.connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + if not document: + doc_result = await db_session.execute( + select(Document) + .join( + SearchSourceConnector, + Document.connector_id == SearchSourceConnector.id, + ) + .filter( + and_( + Document.search_space_id == search_space_id, + Document.document_type == DocumentType.ONEDRIVE_FILE, + func.lower( + cast( + Document.document_metadata[ + "onedrive_file_name" + ], + String, + ) + ) + == func.lower(file_name), + SearchSourceConnector.user_id == user_id, + ) + ) + .order_by(Document.updated_at.desc().nullslast()) + .limit(1) ) - ) - ) - connector = conn_result.scalars().first() - if not connector: - return { - "status": "error", - "message": "OneDrive connector not found or access denied.", - } + document = doc_result.scalars().first() - cfg = connector.config or {} - if cfg.get("auth_expired"): - return { - "status": "auth_error", - "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "onedrive", - } + if not document: + return { + "status": "not_found", + "message": ( + f"File '{file_name}' not found in your indexed OneDrive files. " + "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " + "or (3) the file name is different." + ), + } - context = { - "file": { - "file_id": file_id, - "name": file_name, - "document_id": document_id, - "web_url": meta.get("web_url"), - }, - "account": { - "id": connector.id, - "name": connector.name, - "user_email": cfg.get("user_email"), - }, - } + if not document.connector_id: + return { + "status": "error", + "message": "Document has no associated connector.", + } - result = request_approval( - action_type="onedrive_file_trash", - tool_name="delete_onedrive_file", - params={ - "file_id": file_id, - "connector_id": connector.id, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + meta = document.document_metadata or {} + file_id = meta.get("onedrive_file_id") + document_id = document.id - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } + if not file_id: + return { + "status": "error", + "message": "File ID is missing. 
Please re-index the file.", + } - final_file_id = result.params.get("file_id", file_id) - final_connector_id = result.params.get("connector_id", connector.id) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if final_connector_id != connector.id: - result = await db_session.execute( + conn_result = await db_session.execute( select(SearchSourceConnector).filter( and_( - SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.id == document.connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type @@ -207,65 +166,130 @@ def create_delete_onedrive_file_tool( ) ) ) - validated_connector = result.scalars().first() - if not validated_connector: + connector = conn_result.scalars().first() + if not connector: return { "status": "error", - "message": "Selected OneDrive connector is invalid or has been disconnected.", + "message": "OneDrive connector not found or access denied.", } - actual_connector_id = validated_connector.id - else: - actual_connector_id = connector.id - logger.info( - f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}" - ) + cfg = connector.config or {} + if cfg.get("auth_expired"): + return { + "status": "auth_error", + "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "onedrive", + } - client = OneDriveClient( - session=db_session, connector_id=actual_connector_id - ) - await client.trash_file(final_file_id) + context = { + "file": { + "file_id": file_id, + "name": file_name, + "document_id": document_id, + "web_url": meta.get("web_url"), + }, + "account": { + "id": connector.id, + "name": connector.name, + "user_email": cfg.get("user_email"), + }, + } - logger.info( - f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}" - ) - - trash_result: dict[str, Any] = { - "status": "success", - "file_id": final_file_id, - "message": f"Successfully moved '{file_name}' to the recycle bin.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - doc = doc_result.scalars().first() - if doc: - await db_session.delete(doc) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + result = request_approval( + action_type="onedrive_file_trash", + tool_name="delete_onedrive_file", + params={ + "file_id": file_id, + "connector_id": connector.id, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - return trash_result + if result.rejected: + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_file_id = result.params.get("file_id", file_id) + final_connector_id = result.params.get("connector_id", connector.id) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + if final_connector_id != connector.id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + and_( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id + == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + ) + ) + ) + validated_connector = result.scalars().first() + if not validated_connector: + return { + "status": "error", + "message": "Selected OneDrive connector is invalid or has been disconnected.", + } + actual_connector_id = validated_connector.id + else: + actual_connector_id = connector.id + + logger.info( + f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}" + ) + + client = OneDriveClient( + session=db_session, connector_id=actual_connector_id + ) + await client.trash_file(final_file_id) + + logger.info( + f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}" + ) + + trash_result: dict[str, Any] = { + "status": "success", + "file_id": final_file_id, + "message": f"Successfully moved '{file_name}' to the recycle bin.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + doc = doc_result.scalars().first() + if doc: + await db_session.delete(doc) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index e8bab36fd..b842d7a20 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -824,13 +824,22 @@ async def build_tools_async( """Async version of build_tools that also loads MCP tools from database. Design Note: - This function exists because MCP tools require database queries to load user configs, - while built-in tools are created synchronously from static code. + This function exists because MCP tools require database queries to load + user configs, while built-in tools are created synchronously from static + code. - Alternative: We could make build_tools() itself async and always query the database, - but that would force async everywhere even when only using built-in tools. The current - design keeps the simple case (static tools only) synchronous while supporting dynamic - database-loaded tools through this async wrapper. 
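+
+    Call-site rule of thumb (illustrative, not from this module):
+
+        tools = build_tools(deps)              # static tools only, sync
+        tools = await build_tools_async(deps)  # adds DB-loaded MCP tools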
+ Alternative: We could make build_tools() itself async and always query + the database, but that would force async everywhere even when only using + built-in tools. The current design keeps the simple case (static tools + only) synchronous while supporting dynamic database-loaded tools through + this async wrapper. + + Phase 1.3: built-in tool construction (CPU; runs in a thread pool to + avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on + the event loop) are kicked off concurrently. Cold-path savings are + bounded by the slower of the two — typically MCP at ~200ms-1.7s — + so the parallelization recovers the ~50-200ms previously spent + serially on built-in construction. Args: dependencies: Dict containing all possible dependencies @@ -843,33 +852,70 @@ async def build_tools_async( List of configured tool instances ready for the agent, including MCP tools. """ + import asyncio import time _perf_log = logging.getLogger("surfsense.perf") _perf_log.setLevel(logging.DEBUG) + can_load_mcp = ( + include_mcp_tools + and "db_session" in dependencies + and "search_space_id" in dependencies + ) + + # Built-in tool construction is synchronous + CPU-only. Off-loop it so + # MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is pure + # function over its inputs — safe to thread-shift. _t0 = time.perf_counter() - tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools) + builtin_task = asyncio.create_task( + asyncio.to_thread( + build_tools, dependencies, enabled_tools, disabled_tools, additional_tools + ) + ) + + mcp_task: asyncio.Task | None = None + if can_load_mcp: + mcp_task = asyncio.create_task( + load_mcp_tools( + dependencies["db_session"], + dependencies["search_space_id"], + ) + ) + + # Surface failures from each task independently so a flaky MCP + # endpoint never poisons built-in tool registration. ``return_exceptions`` + # gives us per-task exceptions instead of dropping the second result + # when the first raises. + if mcp_task is not None: + builtin_result, mcp_result = await asyncio.gather( + builtin_task, mcp_task, return_exceptions=True + ) + else: + builtin_result = await builtin_task + mcp_result = None + + if isinstance(builtin_result, BaseException): + raise builtin_result # built-in registration failure is non-recoverable + tools: list[BaseTool] = builtin_result _perf_log.info( - "[build_tools_async] Built-in tools in %.3fs (%d tools)", + "[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)", time.perf_counter() - _t0, len(tools), ) - # Load MCP tools if requested and dependencies are available - if ( - include_mcp_tools - and "db_session" in dependencies - and "search_space_id" in dependencies - ): - try: - _t0 = time.perf_counter() - mcp_tools = await load_mcp_tools( - dependencies["db_session"], - dependencies["search_space_id"], + if mcp_task is not None: + if isinstance(mcp_result, BaseException): + # ``return_exceptions=True`` captures the exception out-of-band, + # so ``sys.exc_info()`` is empty here. Pass the captured + # exception via ``exc_info=`` to get a real traceback. 
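+            # Minimal illustration of the pitfall (hypothetical snippet,
+            # not from this module):
+            #
+            #     err = (await asyncio.gather(task, return_exceptions=True))[0]
+            #     logging.error("failed: %s", err)                # message only
+            #     logging.error("failed: %s", err, exc_info=err)  # adds traceback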
+ logging.error( + "Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result ) + else: + mcp_tools = mcp_result or [] _perf_log.info( - "[build_tools_async] MCP tools loaded in %.3fs (%d tools)", + "[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)", time.perf_counter() - _t0, len(mcp_tools), ) @@ -879,8 +925,6 @@ async def build_tools_async( len(mcp_tools), [t.name for t in mcp_tools], ) - except Exception as e: - logging.exception("Failed to load MCP tools: %s", e) logging.info( "Total tools for agent: %d — %s", diff --git a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py b/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py index b8b1527c7..2965f2f02 100644 --- a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py +++ b/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py @@ -15,7 +15,7 @@ from langchain_core.tools import tool from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument +from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker from app.utils.document_converters import embed_text @@ -124,12 +124,19 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession): """ Factory function to create the search_surfsense_docs tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + Args: - db_session: Database session for executing queries + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. Returns: A configured tool function for searching Surfsense documentation """ + del db_session # per-call session — see docstring @tool async def search_surfsense_docs(query: str, top_k: int = 10) -> str: @@ -155,10 +162,11 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession): Returns: Relevant documentation content formatted with chunk IDs for citations """ - return await search_surfsense_docs_async( - query=query, - db_session=db_session, - top_k=top_k, - ) + async with async_session_maker() as db_session: + return await search_surfsense_docs_async( + query=query, + db_session=db_session, + top_k=top_k, + ) return search_surfsense_docs diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py index d7b000853..0fc52b5c7 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import GRAPH_API, get_access_token, get_teams_connector logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_list_teams_channels_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the list_teams_channels tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. 
+ + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured list_teams_channels tool + """ + del db_session # per-call session — see docstring + @tool async def list_teams_channels() -> dict[str, Any]: """List all Microsoft Teams and their channels the user has access to. @@ -23,63 +42,66 @@ def create_list_teams_channels_tool( Dictionary with status and a list of teams, each containing team_id, team_name, and a list of channels (id, name). """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Teams tool not properly configured."} try: - connector = await get_teams_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Teams connector found."} - - token = await get_access_token(db_session, connector) - headers = {"Authorization": f"Bearer {token}"} - - async with httpx.AsyncClient(timeout=20.0) as client: - teams_resp = await client.get( - f"{GRAPH_API}/me/joinedTeams", headers=headers + async with async_session_maker() as db_session: + connector = await get_teams_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Teams connector found."} - if teams_resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. Please re-authenticate.", - "connector_type": "teams", - } - if teams_resp.status_code != 200: - return { - "status": "error", - "message": f"Graph API error: {teams_resp.status_code}", - } + token = await get_access_token(db_session, connector) + headers = {"Authorization": f"Bearer {token}"} - teams_data = teams_resp.json().get("value", []) - result_teams = [] - - async with httpx.AsyncClient(timeout=20.0) as client: - for team in teams_data: - team_id = team["id"] - ch_resp = await client.get( - f"{GRAPH_API}/teams/{team_id}/channels", - headers=headers, - ) - channels = [] - if ch_resp.status_code == 200: - channels = [ - {"id": ch["id"], "name": ch.get("displayName", "")} - for ch in ch_resp.json().get("value", []) - ] - result_teams.append( - { - "team_id": team_id, - "team_name": team.get("displayName", ""), - "channels": channels, - } + async with httpx.AsyncClient(timeout=20.0) as client: + teams_resp = await client.get( + f"{GRAPH_API}/me/joinedTeams", headers=headers ) - return { - "status": "success", - "teams": result_teams, - "total_teams": len(result_teams), - } + if teams_resp.status_code == 401: + return { + "status": "auth_error", + "message": "Teams token expired. 
Please re-authenticate.", + "connector_type": "teams", + } + if teams_resp.status_code != 200: + return { + "status": "error", + "message": f"Graph API error: {teams_resp.status_code}", + } + + teams_data = teams_resp.json().get("value", []) + result_teams = [] + + async with httpx.AsyncClient(timeout=20.0) as client: + for team in teams_data: + team_id = team["id"] + ch_resp = await client.get( + f"{GRAPH_API}/teams/{team_id}/channels", + headers=headers, + ) + channels = [] + if ch_resp.status_code == 200: + channels = [ + {"id": ch["id"], "name": ch.get("displayName", "")} + for ch in ch_resp.json().get("value", []) + ] + result_teams.append( + { + "team_id": team_id, + "team_name": team.get("displayName", ""), + "channels": channels, + } + ) + + return { + "status": "success", + "teams": result_teams, + "total_teams": len(result_teams), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py index d24a7e4d3..0ebda021e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import GRAPH_API, get_access_token, get_teams_connector logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_read_teams_messages_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_teams_messages tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_teams_messages tool + """ + del db_session # per-call session — see docstring + @tool async def read_teams_messages( team_id: str, @@ -32,65 +51,68 @@ def create_read_teams_messages_tool( Dictionary with status and a list of messages including id, sender, content, timestamp. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Teams tool not properly configured."} limit = min(limit, 50) try: - connector = await get_teams_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Teams connector found."} - - token = await get_access_token(db_session, connector) - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.get( - f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages", - headers={"Authorization": f"Bearer {token}"}, - params={"$top": limit}, + async with async_session_maker() as db_session: + connector = await get_teams_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Teams connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. 
Please re-authenticate.", - "connector_type": "teams", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Insufficient permissions to read this channel.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Graph API error: {resp.status_code}", - } + token = await get_access_token(db_session, connector) - raw_msgs = resp.json().get("value", []) - messages = [] - for m in raw_msgs: - sender = m.get("from", {}) - user_info = sender.get("user", {}) if sender else {} - body = m.get("body", {}) - messages.append( - { - "id": m.get("id"), - "sender": user_info.get("displayName", "Unknown"), - "content": body.get("content", ""), - "content_type": body.get("contentType", "text"), - "timestamp": m.get("createdDateTime", ""), + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.get( + f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages", + headers={"Authorization": f"Bearer {token}"}, + params={"$top": limit}, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Teams token expired. Please re-authenticate.", + "connector_type": "teams", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Insufficient permissions to read this channel.", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Graph API error: {resp.status_code}", } - ) - return { - "status": "success", - "team_id": team_id, - "channel_id": channel_id, - "messages": messages, - "total": len(messages), - } + raw_msgs = resp.json().get("value", []) + messages = [] + for m in raw_msgs: + sender = m.get("from", {}) + user_info = sender.get("user", {}) if sender else {} + body = m.get("body", {}) + messages.append( + { + "id": m.get("id"), + "sender": user_info.get("displayName", "Unknown"), + "content": body.get("content", ""), + "content_type": body.get("contentType", "text"), + "timestamp": m.get("createdDateTime", ""), + } + ) + + return { + "status": "success", + "team_id": team_id, + "channel_id": channel_id, + "messages": messages, + "total": len(messages), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py index fd8d00870..6f40d27e1 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py @@ -6,6 +6,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from ._auth import GRAPH_API, get_access_token, get_teams_connector @@ -17,6 +18,23 @@ def create_send_teams_message_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the send_teams_message tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
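+
+    A minimal sketch of this shared factory pattern (names here are
+    illustrative, not part of this module):
+
+        def create_example_tool(db_session, search_space_id, user_id):
+            del db_session  # kept for registry compatibility only
+
+            @tool
+            async def example_tool() -> dict:
+                async with async_session_maker() as session:
+                    ...  # every query runs on the fresh per-call session
+
+            return example_tool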
+ + Returns: + Configured send_teams_message tool + """ + del db_session # per-call session — see docstring + @tool async def send_teams_message( team_id: str, @@ -39,70 +57,73 @@ def create_send_teams_message_tool( IMPORTANT: - If status is "rejected", the user explicitly declined. Do NOT retry. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Teams tool not properly configured."} try: - connector = await get_teams_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Teams connector found."} + async with async_session_maker() as db_session: + connector = await get_teams_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Teams connector found."} - result = request_approval( - action_type="teams_send_message", - tool_name="send_teams_message", - params={ - "team_id": team_id, - "channel_id": channel_id, - "content": content, - }, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Message was not sent.", - } - - final_content = result.params.get("content", content) - final_team = result.params.get("team_id", team_id) - final_channel = result.params.get("channel_id", channel_id) - - token = await get_access_token(db_session, connector) - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.post( - f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages", - headers={ - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", + result = request_approval( + action_type="teams_send_message", + tool_name="send_teams_message", + params={ + "team_id": team_id, + "channel_id": channel_id, + "content": content, }, - json={"body": {"content": final_content}}, + context={"connector_id": connector.id}, ) - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. Please re-authenticate.", - "connector_type": "teams", - } - if resp.status_code == 403: - return { - "status": "insufficient_permissions", - "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Message was not sent.", + } - msg_data = resp.json() - return { - "status": "success", - "message_id": msg_data.get("id"), - "message": "Message sent to Teams channel.", - } + final_content = result.params.get("content", content) + final_team = result.params.get("team_id", team_id) + final_channel = result.params.get("channel_id", channel_id) + + token = await get_access_token(db_session, connector) + + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.post( + f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + json={"body": {"content": final_content}}, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Teams token expired. 
Please re-authenticate.", + "connector_type": "teams", + } + if resp.status_code == 403: + return { + "status": "insufficient_permissions", + "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.", + } + if resp.status_code not in (200, 201): + return { + "status": "error", + "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}", + } + + msg_data = resp.json() + return { + "status": "success", + "message_id": msg_data.get("id"), + "message": "Message sent to Teams channel.", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/update_memory.py b/surfsense_backend/app/agents/new_chat/tools/update_memory.py index 4128ac0dc..fbc9edbba 100644 --- a/surfsense_backend/app/agents/new_chat/tools/update_memory.py +++ b/surfsense_backend/app/agents/new_chat/tools/update_memory.py @@ -26,7 +26,7 @@ from langchain_core.tools import tool from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.db import SearchSpace, User +from app.db import SearchSpace, User, async_session_maker logger = logging.getLogger(__name__) @@ -295,6 +295,25 @@ def create_update_memory_tool( db_session: AsyncSession, llm: Any | None = None, ): + """Factory function to create the user-memory update tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + The session's bound ``commit``/``rollback`` methods are captured at + call time, after ``async with`` has bound ``db_session`` locally. + + Args: + user_id: ID of the user whose memory document is being updated. + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + llm: Optional LLM for the forced-rewrite path. + + Returns: + Configured update_memory tool for the user-memory scope. + """ + del db_session # per-call session — see docstring uid = UUID(user_id) if isinstance(user_id, str) else user_id @tool @@ -311,26 +330,26 @@ def create_update_memory_tool( updated_memory: The FULL updated markdown document (not a diff). 
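+
+        Example (illustrative; pass the complete document, never a
+        fragment or patch):
+
+            update_memory(
+                updated_memory="# Memory\n\n- Prefers concise answers\n"
+            )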
""" try: - result = await db_session.execute(select(User).where(User.id == uid)) - user = result.scalars().first() - if not user: - return {"status": "error", "message": "User not found."} + async with async_session_maker() as db_session: + result = await db_session.execute(select(User).where(User.id == uid)) + user = result.scalars().first() + if not user: + return {"status": "error", "message": "User not found."} - old_memory = user.memory_md + old_memory = user.memory_md - return await _save_memory( - updated_memory=updated_memory, - old_memory=old_memory, - llm=llm, - apply_fn=lambda content: setattr(user, "memory_md", content), - commit_fn=db_session.commit, - rollback_fn=db_session.rollback, - label="memory", - scope="user", - ) + return await _save_memory( + updated_memory=updated_memory, + old_memory=old_memory, + llm=llm, + apply_fn=lambda content: setattr(user, "memory_md", content), + commit_fn=db_session.commit, + rollback_fn=db_session.rollback, + label="memory", + scope="user", + ) except Exception as e: logger.exception("Failed to update user memory: %s", e) - await db_session.rollback() return { "status": "error", "message": f"Failed to update memory: {e}", @@ -344,6 +363,27 @@ def create_update_team_memory_tool( db_session: AsyncSession, llm: Any | None = None, ): + """Factory function to create the team-memory update tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + The session's bound ``commit``/``rollback`` methods are captured at + call time, after ``async with`` has bound ``db_session`` locally. + + Args: + search_space_id: ID of the search space whose team memory is being + updated. + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + llm: Optional LLM for the forced-rewrite path. + + Returns: + Configured update_memory tool for the team-memory scope. + """ + del db_session # per-call session — see docstring + @tool async def update_memory(updated_memory: str) -> dict[str, Any]: """Update the team's shared memory document for this search space. @@ -359,28 +399,30 @@ def create_update_team_memory_tool( updated_memory: The FULL updated markdown document (not a diff). 
""" try: - result = await db_session.execute( - select(SearchSpace).where(SearchSpace.id == search_space_id) - ) - space = result.scalars().first() - if not space: - return {"status": "error", "message": "Search space not found."} + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSpace).where(SearchSpace.id == search_space_id) + ) + space = result.scalars().first() + if not space: + return {"status": "error", "message": "Search space not found."} - old_memory = space.shared_memory_md + old_memory = space.shared_memory_md - return await _save_memory( - updated_memory=updated_memory, - old_memory=old_memory, - llm=llm, - apply_fn=lambda content: setattr(space, "shared_memory_md", content), - commit_fn=db_session.commit, - rollback_fn=db_session.rollback, - label="team memory", - scope="team", - ) + return await _save_memory( + updated_memory=updated_memory, + old_memory=old_memory, + llm=llm, + apply_fn=lambda content: setattr( + space, "shared_memory_md", content + ), + commit_fn=db_session.commit, + rollback_fn=db_session.rollback, + label="team memory", + scope="team", + ) except Exception as e: logger.exception("Failed to update team memory: %s", e) - await db_session.rollback() return { "status": "error", "message": f"Failed to update team memory: {e}", diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 016c2de42..08194e7fb 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -31,6 +31,7 @@ from app.config import ( initialize_image_gen_router, initialize_llm_router, initialize_openrouter_integration, + initialize_pricing_registration, initialize_vision_llm_router, ) from app.db import User, create_db_and_tables, get_async_session @@ -420,6 +421,135 @@ def _stop_openrouter_background_refresh() -> None: OpenRouterIntegrationService.get_instance().stop_background_refresh() +async def _warm_agent_jit_caches() -> None: + """Pay the LangChain / LangGraph / Deepagents JIT cost at startup. + + Why + ---- + A cold ``create_agent`` + ``StateGraph.compile()`` + Pydantic schema + generation chain takes 1.5-2 seconds of pure CPU on first invocation + inside any Python process: the graph compiler builds reducers, + Pydantic v2 generates and JITs validator schemas, deepagents + eagerly compiles its general-purpose subagent, etc. Subsequent + compiles in the same process pay only ~50% of that cost (the lazy + JIT bits are cached in module-level dicts). + + Doing one throwaway compile during ``lifespan`` startup pre-pays + that cost so the *first real request* doesn't. We do NOT prime + :mod:`agent_cache` because the cache key requires real + ``thread_id`` / ``user_id`` / ``search_space_id`` / etc. — the + throwaway agent is genuinely thrown away and immediately collected. + + Safety + ------ + * No DB access. We construct a stub LLM (no real keys), pass an + empty tools list, and pass ``checkpointer=None`` so we never + touch Postgres. + * Bounded by ``asyncio.wait_for`` so a hang here can never block + worker startup. On any failure, we log + swallow — the worst + case is the first real request pays the full cold cost (i.e. + pre-warmup behaviour). 
+ """ + import time as _time + + logger = logging.getLogger(__name__) + t0 = _time.perf_counter() + try: + from langchain.agents import create_agent + from langchain.agents.middleware import ( + ModelCallLimitMiddleware, + TodoListMiddleware, + ToolCallLimitMiddleware, + ) + from langchain_core.language_models.fake_chat_models import ( + FakeListChatModel, + ) + from langchain_core.tools import tool + + from app.agents.new_chat.context import SurfSenseContextSchema + + # Minimal LLM stub. ``FakeListChatModel`` satisfies + # ``BaseChatModel`` without any network or auth — perfect for + # exercising the compile path without side effects. + stub_llm = FakeListChatModel(responses=["warmup-response"]) + + # Two trivial tools with arg + return schemas — exercises the + # Pydantic v2 schema JIT path. Without at least one tool the + # graph compile skips the tool-loop bytecode generation that + # accounts for ~30-50% of cold compile cost. + @tool + def _warmup_tool_a(query: str, limit: int = 5) -> str: + """Warmup tool A — never actually invoked.""" + return query[:limit] + + @tool + def _warmup_tool_b(name: str, value: float | None = None) -> dict[str, object]: + """Warmup tool B — never actually invoked.""" + return {"name": name, "value": value} + + # A handful of common middleware so the compile pre-pays the + # ``AgentMiddleware`` resolver path. These instances never run + # because the throwaway agent is immediately collected. + # ``SubAgentMiddleware`` is the single heaviest line in cold + # ``create_surfsense_deep_agent`` (1.5-2s of CPU per call to + # compile its general-purpose subagent's full inner graph), + # so we include it here to make sure that compile path is JIT'd. + warmup_middleware: list = [ + TodoListMiddleware(), + ModelCallLimitMiddleware( + thread_limit=120, run_limit=80, exit_behavior="end" + ), + ToolCallLimitMiddleware( + thread_limit=300, run_limit=80, exit_behavior="continue" + ), + ] + try: + from deepagents import SubAgentMiddleware + from deepagents.backends import StateBackend + from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT + + gp_warmup_spec = { # type: ignore[var-annotated] + **GENERAL_PURPOSE_SUBAGENT, + "model": stub_llm, + "tools": [_warmup_tool_a], + "middleware": [TodoListMiddleware()], + } + warmup_middleware.append( + SubAgentMiddleware(backend=StateBackend, subagents=[gp_warmup_spec]) + ) + except Exception: + # Deepagents missing/incompatible — middleware-only warmup + # still produces a useful (smaller) speedup. + logger.debug("[startup] SubAgentMiddleware warmup skipped", exc_info=True) + + compiled = create_agent( + stub_llm, + tools=[_warmup_tool_a, _warmup_tool_b], + system_prompt="You are a warmup stub.", + middleware=warmup_middleware, + context_schema=SurfSenseContextSchema, + checkpointer=None, + ) + + # Touch the compiled graph's stream_channels / nodes so any + # remaining lazy schema work fires now instead of on first + # real invocation. 
+ _ = list(getattr(compiled, "nodes", {}).keys()) + + del compiled + logger.info( + "[startup] Agent JIT warmup completed in %.3fs", + _time.perf_counter() - t0, + ) + except Exception: + logger.warning( + "[startup] Agent JIT warmup failed in %.3fs (non-fatal — first " + "real request will pay the full compile cost)", + _time.perf_counter() - t0, + exc_info=True, + ) + + @asynccontextmanager async def lifespan(app: FastAPI): # Tune GC: lower gen-2 threshold so long-lived garbage is collected @@ -432,6 +562,7 @@ async def lifespan(app: FastAPI): await setup_checkpointer_tables() initialize_openrouter_integration() _start_openrouter_background_refresh() + initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() initialize_vision_llm_router() @@ -443,6 +574,18 @@ async def lifespan(app: FastAPI): "Docs will be indexed on the next restart." ) + # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays + # worker readiness. ``shield`` so Uvicorn cancelling startup + # doesn't leave half-warmed Pydantic schemas in an inconsistent + # state. + try: + await asyncio.wait_for(asyncio.shield(_warm_agent_jit_caches()), timeout=20) + except (TimeoutError, Exception): # pragma: no cover - defensive + logging.getLogger(__name__).warning( + "[startup] Agent JIT warmup hit timeout/error — skipping; " + "first real request will pay the full compile cost." + ) + log_system_snapshot("startup_complete") yield @@ -452,6 +595,23 @@ async def lifespan(app: FastAPI): def registration_allowed(): + """Master auth kill switch keyed on the REGISTRATION_ENABLED env var. + + Despite the name, this dependency does NOT only gate registration. When + REGISTRATION_ENABLED is FALSE it intentionally blocks every auth surface + that could mint or refresh a session for an attacker: + + * email/password ``POST /auth/register`` + * email/password ``POST /auth/jwt/login`` + * the Google OAuth router (``/auth/google/authorize`` and the shared + ``/auth/google/callback`` handles both new signups and login for + existing users, so flipping this off locks both) + * the bespoke ``/auth/google/authorize-redirect`` helper used by the UI + + Use it as a temporary "freeze all new sessions" lever during incident + response. It is not a way to disable signup while keeping login working; + for that, override ``UserManager.oauth_callback`` instead. + """ if not config.REGISTRATION_ENABLED: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled" @@ -596,32 +756,45 @@ app.add_middleware( allow_headers=["*"], # Allows all headers ) -app.include_router( - fastapi_users.get_auth_router(auth_backend), - prefix="/auth/jwt", - tags=["auth"], - dependencies=[Depends(rate_limit_login)], -) -app.include_router( - fastapi_users.get_register_router(UserRead, UserCreate), - prefix="/auth", - tags=["auth"], - dependencies=[ - Depends(rate_limit_register), - Depends(registration_allowed), # blocks registration when disabled - ], -) -app.include_router( - fastapi_users.get_reset_password_router(), - prefix="/auth", - tags=["auth"], - dependencies=[Depends(rate_limit_password_reset)], -) -app.include_router( - fastapi_users.get_verify_router(UserRead), - prefix="/auth", - tags=["auth"], -) +# Password / email-based auth routers are only mounted when not running in +# Google-OAuth-only mode. Mounting them in OAuth-only prod previously left +# POST /auth/register reachable, which is the bypass that allowed bots to +# create non-OAuth users in spite of AUTH_TYPE=GOOGLE. 
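+# A quick post-deploy probe for OAuth-only mode (illustrative; the route
+# should be absent, so expect 404 rather than 201/422):
+#
+#     curl -i -X POST https://<host>/auth/register \
+#       -H 'Content-Type: application/json' \
+#       -d '{"email": "probe@example.com", "password": "x"}'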
+if config.AUTH_TYPE != "GOOGLE": + app.include_router( + fastapi_users.get_auth_router(auth_backend), + prefix="/auth/jwt", + tags=["auth"], + dependencies=[ + Depends(rate_limit_login), + Depends( + registration_allowed + ), # honour REGISTRATION_ENABLED kill switch on login too + ], + ) + app.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], + dependencies=[ + Depends(rate_limit_register), + Depends(registration_allowed), + ], + ) + app.include_router( + fastapi_users.get_reset_password_router(), + prefix="/auth", + tags=["auth"], + dependencies=[Depends(rate_limit_password_reset)], + ) + app.include_router( + fastapi_users.get_verify_router(UserRead), + prefix="/auth", + tags=["auth"], + ) + +# /users/me (read/update profile) is needed in every auth mode, so it stays +# mounted unconditionally. app.include_router( fastapi_users.get_users_router(UserRead, UserUpdate), prefix="/users", @@ -679,16 +852,25 @@ if config.AUTH_TYPE == "GOOGLE": ), prefix="/auth/google", tags=["auth"], - dependencies=[ - Depends(registration_allowed) - ], # blocks OAuth registration when disabled + # REGISTRATION_ENABLED is a master auth kill switch: when set to FALSE + # it blocks BOTH new OAuth signups AND login of existing OAuth users + # (the fastapi-users OAuth router shares one callback for create+login, + # so this dependency closes both paths together). + dependencies=[Depends(registration_allowed)], ) # Add a redirect-based authorize endpoint for Firefox/Safari compatibility # This endpoint performs a server-side redirect instead of returning JSON # which fixes cross-site cookie issues where browsers don't send cookies - # set via cross-origin fetch requests on subsequent redirects - @app.get("/auth/google/authorize-redirect", tags=["auth"]) + # set via cross-origin fetch requests on subsequent redirects. + # The registration_allowed dependency mirrors the OAuth router above so + # the kill switch fails fast here instead of bouncing users to Google + # only to 403 on the callback. + @app.get( + "/auth/google/authorize-redirect", + tags=["auth"], + dependencies=[Depends(registration_allowed)], + ) async def google_authorize_redirect( request: Request, ): diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 58a8b0f39..74710d5e1 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -22,10 +22,12 @@ def init_worker(**kwargs): initialize_image_gen_router, initialize_llm_router, initialize_openrouter_integration, + initialize_pricing_registration, initialize_vision_llm_router, ) initialize_openrouter_integration() + initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() initialize_vision_llm_router() diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 86482767d..f6f0c7f62 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -47,11 +47,37 @@ def load_global_llm_configs(): data = yaml.safe_load(f) configs = data.get("global_llm_configs", []) + # Lazy import keeps the `app.config` -> `app.services` edge one-way + # and matches the `provider_api_base` pattern used elsewhere. 
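+    # (For reference, the LiteLLM model map it consults can be queried
+    # directly; a sketch, assuming the current litellm API:
+    #
+    #     import litellm
+    #     litellm.supports_vision(model="gpt-4o")  # -> True
+    #
+    # ``derive_supports_image_input`` layers the YAML-override and
+    # default-allow rules described below on top of that lookup.)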
+ from app.services.provider_capabilities import derive_supports_image_input + seen_slugs: dict[str, int] = {} for cfg in configs: cfg.setdefault("billing_tier", "free") cfg.setdefault("anonymous_enabled", False) cfg.setdefault("seo_enabled", False) + # Capability flag: explicit YAML override always wins. When the + # operator has not annotated the model, defer to LiteLLM's + # authoritative model map (`supports_vision`) which already + # knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are + # vision-capable. Unknown / unmapped models default-allow so + # we don't lock the user out of a freshly added third-party + # entry; the streaming-task safety net (driven by + # `is_known_text_only_chat_model`) is the only place a False + # actually blocks a request. + if "supports_image_input" not in cfg: + litellm_params = cfg.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + cfg["supports_image_input"] = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) if cfg.get("seo_enabled") and cfg.get("seo_slug"): slug = cfg["seo_slug"] @@ -63,6 +89,27 @@ def load_global_llm_configs(): else: seen_slugs[slug] = cfg.get("id", 0) + # Stamp Auto (Fastest) ranking metadata. YAML configs are always + # Tier A — operator-curated, locked first when premium-eligible. + # The OpenRouter refresh tick later re-stamps health for any cfg + # whose provider == "OPENROUTER" via _enrich_health. + try: + from app.services.quality_score import static_score_yaml + + for cfg in configs: + cfg["auto_pin_tier"] = "A" + static_q = static_score_yaml(cfg) + cfg["quality_score_static"] = static_q + cfg["quality_score"] = static_q + cfg["quality_score_health"] = None + # YAML cfgs whose provider is OPENROUTER are also subject + # to health gating against their own /endpoints data — a + # hand-picked dead OR model is still dead. _enrich_health + # re-stamps health_gated for them on the next refresh tick. + cfg["health_gated"] = False + except Exception as e: + print(f"Warning: Failed to score global LLM configs: {e}") + return configs except Exception as e: print(f"Warning: Failed to load global LLM configs: {e}") @@ -117,7 +164,11 @@ def load_global_image_gen_configs(): try: with open(global_config_file, encoding="utf-8") as f: data = yaml.safe_load(f) - return data.get("global_image_generation_configs", []) + configs = data.get("global_image_generation_configs", []) or [] + for cfg in configs: + if isinstance(cfg, dict): + cfg.setdefault("billing_tier", "free") + return configs except Exception as e: print(f"Warning: Failed to load global image generation configs: {e}") return [] @@ -132,7 +183,11 @@ def load_global_vision_llm_configs(): try: with open(global_config_file, encoding="utf-8") as f: data = yaml.safe_load(f) - return data.get("global_vision_llm_configs", []) + configs = data.get("global_vision_llm_configs", []) or [] + for cfg in configs: + if isinstance(cfg, dict): + cfg.setdefault("billing_tier", "free") + return configs except Exception as e: print(f"Warning: Failed to load global vision LLM configs: {e}") return [] @@ -194,6 +249,9 @@ def load_openrouter_integration_settings() -> dict | None: """ Load OpenRouter integration settings from the YAML config. + Emits startup warnings for deprecated keys (``billing_tier``, + ``anonymous_enabled``) and seeds their replacements for back-compat. 
+ Returns: dict with settings if present and enabled, None otherwise """ @@ -206,9 +264,40 @@ def load_openrouter_integration_settings() -> dict | None: with open(global_config_file, encoding="utf-8") as f: data = yaml.safe_load(f) settings = data.get("openrouter_integration") - if settings and settings.get("enabled"): - return settings - return None + if not settings or not settings.get("enabled"): + return None + + if "billing_tier" in settings: + print( + "Warning: openrouter_integration.billing_tier is deprecated; " + "tier is now derived per model from OpenRouter data " + "(':free' suffix or zero pricing). Remove this key." + ) + + if "anonymous_enabled" in settings: + print( + "Warning: openrouter_integration.anonymous_enabled is " + "deprecated; use anonymous_enabled_paid and/or " + "anonymous_enabled_free instead. Both new flags have been " + "seeded from the legacy value for back-compat." + ) + settings.setdefault( + "anonymous_enabled_paid", settings["anonymous_enabled"] + ) + settings.setdefault( + "anonymous_enabled_free", settings["anonymous_enabled"] + ) + + # Image generation + vision LLM emission are opt-in (issue L). + # OpenRouter's catalogue contains hundreds of image / vision + # capable models; auto-injecting all of them into every + # deployment would explode the model selector and surprise + # operators upgrading from prior versions. Default to False so + # admins must explicitly turn them on. + settings.setdefault("image_generation_enabled", False) + settings.setdefault("vision_enabled", False) + + return settings except Exception as e: print(f"Warning: Failed to load OpenRouter integration settings: {e}") return None @@ -217,9 +306,14 @@ def load_openrouter_integration_settings() -> dict | None: def initialize_openrouter_integration(): """ If enabled, fetch all OpenRouter models and append them to - config.GLOBAL_LLM_CONFIGS as dynamic premium entries. - Should be called BEFORE initialize_llm_router() so the router - correctly excludes premium models from Auto mode. + config.GLOBAL_LLM_CONFIGS as dynamic entries. Each model's ``billing_tier`` + is derived per-model from OpenRouter's API signals (``:free`` suffix or + zero pricing), so free OpenRouter models correctly skip premium quota. + + Should be called BEFORE initialize_llm_router(). Dynamic entries are + tagged ``router_pool_eligible=False`` so the LiteLLM Router pool (used + by title-gen / sub-agent flows) remains scoped to curated YAML configs, + while user-facing Auto-mode thread pinning still considers them. """ settings = load_openrouter_integration_settings() if not settings: @@ -235,16 +329,70 @@ def initialize_openrouter_integration(): if new_configs: config.GLOBAL_LLM_CONFIGS.extend(new_configs) + free_count = sum(1 for c in new_configs if c.get("billing_tier") == "free") + premium_count = sum( + 1 for c in new_configs if c.get("billing_tier") == "premium" + ) print( f"Info: OpenRouter integration added {len(new_configs)} models " - f"(billing_tier={settings.get('billing_tier', 'premium')})" + f"(free={free_count}, premium={premium_count})" ) else: print("Info: OpenRouter integration enabled but no models fetched") + + # Image generation + vision LLM emissions are opt-in (issue L). + # Both reuse the catalogue already cached by ``service.initialize`` + # so we don't make additional network calls here. 
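+        # Per-model tier derivation, sketched (the catalogue entry shape
+        # is an assumption; the real logic lives in the integration
+        # service):
+        #
+        #     def _tier(model: dict) -> str:
+        #         pricing = model.get("pricing") or {}
+        #         free = model["id"].endswith(":free") or (
+        #             pricing.get("prompt") == "0"
+        #             and pricing.get("completion") == "0"
+        #         )
+        #         return "free" if free else "premium"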
+ if settings.get("image_generation_enabled"): + try: + image_configs = service.get_image_generation_configs() + if image_configs: + config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs) + print( + f"Info: OpenRouter integration added {len(image_configs)} " + f"image-generation models" + ) + except Exception as e: + print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}") + + if settings.get("vision_enabled"): + try: + vision_configs = service.get_vision_llm_configs() + if vision_configs: + config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs) + print( + f"Info: OpenRouter integration added {len(vision_configs)} " + f"vision LLM models" + ) + except Exception as e: + print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}") except Exception as e: print(f"Warning: Failed to initialize OpenRouter integration: {e}") +def initialize_pricing_registration(): + """ + Teach LiteLLM the per-token cost of every deployment in + ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled + from the OpenRouter catalogue + any operator-declared YAML pricing). + + Must run AFTER ``initialize_openrouter_integration()`` so the + OpenRouter catalogue is populated and BEFORE the first LLM call so + ``response_cost`` is available in ``TokenTrackingCallback``. + + Failures are logged but never raised — startup must not be blocked + by a missing pricing entry; the worst-case is the model debits 0. + """ + try: + from app.services.pricing_registration import ( + register_pricing_from_global_configs, + ) + + register_pricing_from_global_configs() + except Exception as e: + print(f"Warning: Failed to register LiteLLM pricing: {e}") + + def initialize_llm_router(): """ Initialize the LLM Router service for Auto mode. @@ -389,14 +537,54 @@ class Config: os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100") ) - # Premium token quota settings - PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000")) + # Premium credit (micro-USD) quota settings. + # + # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy + # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are + # still honoured for one release as fall-back values — the prior + # $1-per-1M-tokens Stripe price means every existing value maps 1:1 + # to micros, so operators upgrading without changing their .env still + # get correct behaviour. A startup deprecation warning fires below if + # they're set. + PREMIUM_CREDIT_MICROS_LIMIT = int( + os.getenv("PREMIUM_CREDIT_MICROS_LIMIT") + or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000") + ) STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID") - STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")) + STRIPE_CREDIT_MICROS_PER_UNIT = int( + os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT") + or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000") + ) STRIPE_TOKEN_BUYING_ENABLED = ( os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE" ) + # Safety ceiling on the per-call premium reservation. ``stream_new_chat`` + # estimates an upper-bound cost from ``litellm.get_model_info`` x the + # config's ``quota_reserve_tokens`` and clamps the result to this value + # so a misconfigured "$1000/M" model can't lock the user's whole balance + # on one call. Default $1.00 covers realistic worst-cases (Opus + 4K + # reserve_tokens ≈ $0.36) with headroom. 
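+    # Worked example (illustrative prices): a model at $75/M output tokens
+    # with quota_reserve_tokens=4000 estimates
+    #     4000 * 75 / 1_000_000 USD = $0.30 = 300_000 micros  (no clamp)
+    # while a misdeclared "$1000/M" model would estimate 4_000_000 micros
+    # and be clamped to this ceiling.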
+ QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000")) + + if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv( + "PREMIUM_CREDIT_MICROS_LIMIT" + ): + print( + "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to " + "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the " + "current Stripe price). The old key will be removed in a " + "future release." + ) + if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv( + "STRIPE_CREDIT_MICROS_PER_UNIT" + ): + print( + "Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to " + "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). " + "The old key will be removed in a future release." + ) + # Anonymous / no-login mode settings NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE" MULTI_AGENT_CHAT_ENABLED = ( @@ -412,6 +600,35 @@ class Config: # Default quota reserve tokens when not specified per-model QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000")) + # Per-image reservation (in micro-USD) used by ``billable_call`` for the + # ``POST /image-generations`` endpoint when the global config does not + # override it. $0.05 covers realistic worst-cases for current OpenAI / + # OpenRouter image-gen pricing. Bypassed entirely for free configs. + QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000") + ) + + # Per-podcast reservation (in micro-USD). One agent LLM call generating + # a transcript, typically 5k-20k completion tokens. $0.20 covers a long + # premium-model run. Tune via env. + QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000") + ) + + # Per-video-presentation reservation (in micro-USD). Fan-out of N + # slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``) + # plus refine retries; can produce many premium completions. $1.00 + # covers worst-case. Tune via env. + # + # NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of + # 1_000_000. The override path in ``billable_call`` bypasses the + # per-call clamp in ``estimate_call_reserve_micros``, so this is the + # *actual* hold — raising it via env is fine but means a single video + # task can lock $1+ of credit. + QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000") + ) + # Abuse prevention: concurrent stream cap and CAPTCHA ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2")) ANON_CAPTCHA_REQUEST_THRESHOLD = int( diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 9aca0f022..d92640c8d 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -19,6 +19,24 @@ # Structure matches NewLLMConfig: # - Model configuration (provider, model_name, api_key, etc.) # - Prompt configuration (system_instructions, citations_enabled) +# +# COST-BASED PREMIUM CREDITS: +# Each premium config bills the user's USD-credit balance based on the +# actual provider cost reported by LiteLLM. For models LiteLLM already +# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything. +# For custom Azure deployment names (e.g. 
an in-house "gpt-5.4" deployment) +# or any model LiteLLM doesn't have in its built-in pricing table, declare +# per-token costs inline so they bill correctly: +# +# litellm_params: +# base_model: "my-custom-azure-deploy" +# # USD per token; e.g. 0.000003 == $3.00 per million input tokens +# input_cost_per_token: 0.000003 +# output_cost_per_token: 0.000015 +# +# OpenRouter dynamic models pull pricing automatically from OpenRouter's +# API — no inline declaration needed. Models without resolvable pricing +# debit $0 from the user's balance and log a WARNING. # Router Settings for Auto Mode # These settings control how the LiteLLM Router distributes requests across models @@ -245,31 +263,64 @@ global_llm_configs: # ============================================================================= # When enabled, dynamically fetches ALL available models from the OpenRouter API # and injects them as global configs. This gives premium users access to any model -# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota. +# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota, +# while free-tier OpenRouter models show up with a green Free badge and do NOT +# consume premium quota. # Models are fetched at startup and refreshed periodically in the background. # All calls go through LiteLLM with the openrouter/ prefix. openrouter_integration: enabled: false api_key: "sk-or-your-openrouter-api-key" - # billing_tier: "premium" or "free". Controls whether users need premium tokens. - billing_tier: "premium" - # anonymous_enabled: set true to also show OpenRouter models to no-login users - anonymous_enabled: false + + # Tier is derived PER MODEL from OpenRouter's own API signals: + # - id ends with ":free" -> billing_tier=free + # - pricing.prompt AND pricing.completion == "0" -> billing_tier=free + # - otherwise -> billing_tier=premium + # No global billing_tier knob is honored; any legacy value emits a startup warning. + + # Anonymous access is split by tier so operators can expose only free + # models to no-login users without leaking paid inference. + anonymous_enabled_paid: false + anonymous_enabled_free: false + seo_enabled: false # quota_reserve_tokens: tokens reserved per call for quota enforcement quota_reserve_tokens: 4000 - # id_offset: starting negative ID for dynamically generated configs. - # Must not overlap with your static global_llm_configs IDs above. + # id_offset: base negative ID for dynamically generated configs. + # Model IDs are derived deterministically via BLAKE2b so they survive + # catalogue churn. Must not overlap with your static global_llm_configs IDs. id_offset: -10000 # refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only) refresh_interval_hours: 24 - # rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing. - # OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled - # upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits). - # These values only matter if you set billing_tier to "free" (adding them to Auto mode). - # For premium-only models they are cosmetic. Set conservatively or match your account tier. + + # Rate limits for PAID OpenRouter models. These are used by LiteLLM Router + # for per-deployment accounting when OR premium models participate in the + # shared sub-agent "auto" pool. 
They do NOT cap OpenRouter itself — your + # real account limits live at https://openrouter.ai/settings/limits. rpm: 200 tpm: 1000000 + + # Rate limits for FREE OpenRouter models. Informational only: free OR + # models are intentionally kept OUT of the LiteLLM Router pool, because + # OpenRouter enforces free-tier limits globally per account (~20 RPM + + # 50-1000 daily requests across every ":free" model combined) — + # per-deployment router accounting can't represent a shared bucket + # correctly. Free OR models stay fully available in the model selector + # and for user-facing Auto thread pinning. + free_rpm: 20 + free_tpm: 100000 + + # Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue + # contains hundreds of image- and vision-capable models; turning these on + # injects them into the global Image-Generation / Vision-LLM model + # selectors alongside any static configs. Tier (free/premium) is derived + # per model the same way it is for chat (`:free` suffix or zero pricing). + # When a user picks a premium image/vision model the call debits the + # shared $5 USD-cost-based premium credit pool — so leaving these off + # avoids surprise quota burn on existing deployments. Default: false. + image_generation_enabled: false + vision_enabled: false + litellm_params: max_tokens: 16384 system_instructions: "" diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 91d19fb4f..9fc27fb1f 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -638,6 +638,12 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) + # Auto (Fastest) model pin for this thread: concrete resolved global LLM + # config id. NULL means no pin; Auto will resolve on the next turn. + # Single-writer invariant: only app.services.auto_model_pin_service sets + # or clears this column (plus bulk clears when a search space's + # agent_llm_id changes). Unindexed: all reads are by primary key. + pinned_llm_config_id = Column(Integer, nullable=True) # Relationships search_space = relationship("SearchSpace", back_populates="new_chat_threads") @@ -669,6 +675,23 @@ class NewChatMessage(BaseModel, TimestampMixin): __tablename__ = "new_chat_messages" + # Partial unique index on (thread_id, turn_id, role) where turn_id IS NOT NULL. + # Mirrors alembic migration 141. Lets the streaming agent and the + # legacy frontend appendMessage call coexist idempotently — the second + # writer trips the unique and recovers without creating a duplicate row. + # Partial so legacy NULL turn_id rows and clone/snapshot inserts in + # app/services/public_chat_service.py (which omit turn_id) are unaffected. + __table_args__ = ( + Index( + "uq_new_chat_messages_thread_turn_role", + "thread_id", + "turn_id", + "role", + unique=True, + postgresql_where=text("turn_id IS NOT NULL"), + ), + ) + role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False) # Content stored as JSONB to support rich content (text, tool calls, etc.) content = Column(JSONB, nullable=False) @@ -722,9 +745,26 @@ class TokenUsage(BaseModel, TimestampMixin): __tablename__ = "token_usage" + # Partial unique index on (message_id) where message_id IS NOT NULL. + # Mirrors alembic migration 142. Lets the streaming agent's + # ``finalize_assistant_turn`` and the legacy frontend ``append_message`` + # recovery branch both use ``INSERT ... ON CONFLICT DO NOTHING`` without + # racing on a SELECT-then-INSERT window. 
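+    # A minimal sketch (simplified names; the real statements live in
+    # new_chat_routes.py) of the idempotent write this index enables:
+    #
+    #   stmt = (
+    #       pg_insert(TokenUsage)
+    #       .values(message_id=msg_id, prompt_tokens=12,
+    #               completion_tokens=34, total_tokens=46)
+    #       .on_conflict_do_nothing(
+    #           index_elements=["message_id"],
+    #           index_where=text("message_id IS NOT NULL"),
+    #       )
+    #   )
+    #   await session.execute(stmt)  # losing writer is a silent no-op
+    #
+    # As for why the index is partial rather than table-wide: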
Partial so non-chat usage rows + # (indexing, image generation, podcasts) — which keep ``message_id`` NULL + # because there is no per-message anchor — are unaffected. + __table_args__ = ( + Index( + "uq_token_usage_message_id", + "message_id", + unique=True, + postgresql_where=text("message_id IS NOT NULL"), + ), + ) + prompt_tokens = Column(Integer, nullable=False, default=0) completion_tokens = Column(Integer, nullable=False, default=0) total_tokens = Column(Integer, nullable=False, default=0) + cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0") model_breakdown = Column(JSONB, nullable=True) call_details = Column(JSONB, nullable=True) @@ -1787,7 +1827,15 @@ class PagePurchase(Base, TimestampMixin): class PremiumTokenPurchase(Base, TimestampMixin): - """Tracks Stripe checkout sessions used to grant additional premium token credits.""" + """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units). + + Note: the table name is preserved (``premium_token_purchases``) for + operational continuity even though the unit is now USD micro-credits + instead of raw tokens. The ``credit_micros_granted`` column replaced + the legacy ``tokens_granted`` in migration 140; the stored values + were not transformed because the prior $1 = 1M tokens Stripe price + makes the unit conversion 1:1 numerically. + """ __tablename__ = "premium_token_purchases" __allow_unmapped__ = True @@ -1804,7 +1852,7 @@ class PremiumTokenPurchase(Base, TimestampMixin): ) stripe_payment_intent_id = Column(String(255), nullable=True, index=True) quantity = Column(Integer, nullable=False) - tokens_granted = Column(BigInteger, nullable=False) + credit_micros_granted = Column(BigInteger, nullable=False) amount_total = Column(Integer, nullable=True) currency = Column(String(10), nullable=True) status = Column( @@ -2103,16 +2151,16 @@ if config.AUTH_TYPE == "GOOGLE": ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") - premium_tokens_limit = Column( + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.PREMIUM_TOKEN_LIMIT, - server_default=str(config.PREMIUM_TOKEN_LIMIT), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - premium_tokens_used = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - premium_tokens_reserved = Column( + premium_credit_micros_reserved = Column( BigInteger, nullable=False, default=0, server_default="0" ) @@ -2235,16 +2283,16 @@ else: ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") - premium_tokens_limit = Column( + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.PREMIUM_TOKEN_LIMIT, - server_default=str(config.PREMIUM_TOKEN_LIMIT), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - premium_tokens_used = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - premium_tokens_reserved = Column( + premium_credit_micros_reserved = Column( BigInteger, nullable=False, default=0, server_default="0" ) diff --git a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py index 4bb38b7b0..d45bd780c 100644 --- a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py +++ b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py @@ -68,12 +68,25 @@ class 
EtlPipelineService: etl_service="VISION_LLM", content_type="image", ) - except Exception: - logging.warning( - "Vision LLM failed for %s, falling back to document parser", - request.filename, - exc_info=True, - ) + except Exception as exc: + # Special-case quota exhaustion so we log a clearer message + # — the vision LLM didn't "fail", the user just ran out of + # premium credit. Falling through to the document parser + # is a graceful degradation: OCR/Unstructured still + # extracts text from the image without burning credit. + from app.services.billable_calls import QuotaInsufficientError + + if isinstance(exc, QuotaInsufficientError): + logging.info( + "Vision LLM quota exhausted for %s; falling back to document parser", + request.filename, + ) + else: + logging.warning( + "Vision LLM failed for %s, falling back to document parser", + request.filename, + exc_info=True, + ) else: logging.info( "No vision LLM provided, falling back to document parser for %s", diff --git a/surfsense_backend/app/routes/agent_flags_route.py b/surfsense_backend/app/routes/agent_flags_route.py index 5732a8dfb..99388af66 100644 --- a/surfsense_backend/app/routes/agent_flags_route.py +++ b/surfsense_backend/app/routes/agent_flags_route.py @@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends from pydantic import BaseModel from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags +from app.config import config from app.db import User from app.users import current_active_user @@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel): enable_otel: bool + enable_desktop_local_filesystem: bool + @classmethod def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead: # asdict() avoids missing-field bugs when AgentFeatureFlags grows. - return cls(**asdict(flags)) + return cls( + **asdict(flags), + enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM, + ) @router.get("/agent/flags", response_model=AgentFeatureFlagsRead) diff --git a/surfsense_backend/app/routes/composio_routes.py b/surfsense_backend/app/routes/composio_routes.py index 4bf360365..7bc2addf8 100644 --- a/surfsense_backend/app/routes/composio_routes.py +++ b/surfsense_backend/app/routes/composio_routes.py @@ -649,13 +649,9 @@ async def list_composio_drive_folders( """ List folders AND files in user's Google Drive via Composio. - Uses the same GoogleDriveClient / list_folder_contents path as the native - connector, with Composio-sourced credentials. This means auth errors - propagate identically (Google returns 401 → exception → auth_expired flag). + Uses Composio's Google Drive tool execution path so managed OAuth tokens + do not need to be exposed through connected account state. """ - from app.connectors.google_drive import GoogleDriveClient, list_folder_contents - from app.utils.google_credentials import build_composio_credentials - if not ComposioService.is_enabled(): raise HTTPException( status_code=503, @@ -689,10 +685,37 @@ async def list_composio_drive_folders( detail="Composio connected account not found. 
Please reconnect the connector.", ) - credentials = build_composio_credentials(composio_connected_account_id) - drive_client = GoogleDriveClient(session, connector_id, credentials=credentials) + service = ComposioService() + entity_id = f"surfsense_{user.id}" + items = [] + page_token = None + error = None - items, error = await list_folder_contents(drive_client, parent_id=parent_id) + while True: + page_items, next_token, page_error = await service.get_drive_files( + connected_account_id=composio_connected_account_id, + entity_id=entity_id, + folder_id=parent_id, + page_token=page_token, + page_size=100, + ) + if page_error: + error = page_error + break + + items.extend(page_items) + if not next_token: + break + page_token = next_token + + for item in items: + item["isFolder"] = ( + item.get("mimeType") == "application/vnd.google-apps.folder" + ) + + items.sort( + key=lambda item: (not item["isFolder"], item.get("name", "").lower()) + ) if error: error_lower = error.lower() diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index f558481cf..f1ca3b6bf 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -745,6 +745,51 @@ async def search_document_titles( ) from e +@router.get("/documents/by-virtual-path", response_model=DocumentTitleRead) +async def get_document_by_virtual_path( + search_space_id: int, + virtual_path: str, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Resolve a knowledge-base document id by exact virtual path.""" + try: + await check_permission( + session, + user, + search_space_id, + Permission.DOCUMENTS_READ.value, + "You don't have permission to read documents in this search space", + ) + + result = await session.execute( + select( + Document.id, + Document.title, + Document.document_type, + ).filter( + Document.search_space_id == search_space_id, + Document.document_metadata["virtual_path"].as_string() == virtual_path, + ) + ) + row = result.first() + if row is None: + raise HTTPException(status_code=404, detail="Document not found") + + return DocumentTitleRead( + id=row.id, + title=row.title, + document_type=row.document_type, + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to resolve document by virtual path: {e!s}", + ) from e + + @router.get("/documents/status", response_model=DocumentStatusBatchResponse) async def get_documents_status( search_space_id: int, diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 97a3559b9..018234ad5 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -36,11 +36,17 @@ from app.schemas import ( ImageGenerationListRead, ImageGenerationRead, ) +from app.services.billable_calls import ( + DEFAULT_IMAGE_RESERVE_MICROS, + QuotaInsufficientError, + billable_call, +) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, ImageGenRouterService, is_image_gen_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -82,14 +88,62 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: return None +def 
_resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
+    """Resolve the LiteLLM provider prefix used in model strings."""
+    if custom_provider:
+        return custom_provider
+    return _PROVIDER_MAP.get(provider.upper(), provider.lower())
+
+
 def _build_model_string(
     provider: str, model_name: str, custom_provider: str | None
 ) -> str:
     """Build a litellm model string from provider + model_name."""
-    if custom_provider:
-        return f"{custom_provider}/{model_name}"
-    prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
-    return f"{prefix}/{model_name}"
+    return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
+
+
+async def _resolve_billing_for_image_gen(
+    session: AsyncSession,
+    config_id: int | None,
+    search_space: SearchSpace,
+) -> tuple[str, str, int]:
+    """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
+
+    The resolution mirrors ``_execute_image_generation``'s lookup tree but
+    only extracts the fields needed for billing — we do this *before*
+    ``billable_call`` so the reservation is correctly sized for the
+    config that will actually run, and so we don't open an
+    ``ImageGeneration`` row for a request that's about to 402.
+
+    User-owned (positive ID) BYOK configs are always free — they cost
+    the user nothing on our side. Auto mode is currently treated as free
+    because the underlying router can dispatch to either premium or
+    free YAML configs and we don't surface the resolved deployment up
+    here yet. Bringing Auto under premium billing would require
+    threading the chosen deployment back from ``ImageGenRouterService``.
+    """
+    resolved_id = config_id
+    if resolved_id is None:
+        resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
+
+    if is_image_gen_auto_mode(resolved_id):
+        return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
+
+    if resolved_id < 0:
+        cfg = _get_global_image_gen_config(resolved_id) or {}
+        billing_tier = str(cfg.get("billing_tier", "free")).lower()
+        base_model = _build_model_string(
+            cfg.get("provider", ""),
+            cfg.get("model_name", ""),
+            cfg.get("custom_provider"),
+        )
+        reserve_micros = int(
+            cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
+        )
+        return (billing_tier, base_model, reserve_micros)
+
+    # Positive ID = user-owned BYOK image-gen config — always free. 
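+    # Illustrative return shapes (hypothetical ids and values, not
+    # fixtures from this repo):
+    #
+    #   await _resolve_billing_for_image_gen(s, None, space_pinned_to_auto)
+    #       -> ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
+    #   await _resolve_billing_for_image_gen(s, -42, space)  # premium YAML
+    #       -> ("premium", "openai/gpt-image-1", 50000)
+    #   await _resolve_billing_for_image_gen(s, 7, space)    # BYOK config,
+    #       handled by the fall-through below: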
+ return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS) async def _execute_image_generation( @@ -138,12 +192,18 @@ async def _execute_image_generation( if not cfg: raise ValueError(f"Global image generation config {config_id} not found") - model_string = _build_model_string( - cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider") + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) + model_string = f"{provider_prefix}/{cfg['model_name']}" gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base if cfg.get("api_version"): gen_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -165,12 +225,18 @@ async def _execute_image_generation( if not db_cfg: raise ValueError(f"Image generation config {config_id} not found") - model_string = _build_model_string( - db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) + model_string = f"{provider_prefix}/{db_cfg.model_name}" gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base if db_cfg.api_version: gen_kwargs["api_version"] = db_cfg.api_version if db_cfg.litellm_params: @@ -225,10 +291,15 @@ async def get_global_image_gen_configs( "litellm_params": {}, "is_global": True, "is_auto_mode": True, + # Auto mode currently treated as free until per-deployment + # billing-tier surfacing lands (see _resolve_billing_for_image_gen). + "billing_tier": "free", + "is_premium": False, } ) for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() safe_configs.append( { "id": cfg.get("id"), @@ -241,6 +312,12 @@ async def get_global_image_gen_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. + "is_premium": billing_tier == "premium", + "quota_reserve_micros": cfg.get("quota_reserve_micros"), } ) @@ -454,7 +531,26 @@ async def create_image_generation( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Create and execute an image generation request.""" + """Create and execute an image generation request. + + Premium configs are gated by the user's shared premium credit pool. + The flow is: + + 1. Permission check + load the search space (cheap, no provider call). + 2. Resolve which config will run so we know its billing tier and the + worst-case reservation size *before* opening any DB rows. + 3. Wrap the entire ImageGeneration row insert + provider call in + ``billable_call``. If quota is denied, ``billable_call`` raises + ``QuotaInsufficientError`` *before* we flush a row, which we + translate to HTTP 402 (no orphaned rows on the user's account, + no inserted error rows for "you ran out of credit"). + 4. 
On success, the actual ``response_cost`` flows through the + LiteLLM callback into the accumulator, and ``billable_call`` + finalizes the debit at exit. Inner ``try/except`` still catches + provider errors and stores them on ``error_message`` (HTTP 200 + with ``error_message`` set is preserved for failed-but-not-quota + scenarios — clients already know how to surface those). + """ try: await check_permission( session, @@ -471,33 +567,70 @@ async def create_image_generation( if not search_space: raise HTTPException(status_code=404, detail="Search space not found") - db_image_gen = ImageGeneration( - prompt=data.prompt, - model=data.model, - n=data.n, - quality=data.quality, - size=data.size, - style=data.style, - response_format=data.response_format, - image_generation_config_id=data.image_generation_config_id, - search_space_id=data.search_space_id, - created_by_id=user.id, + billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen( + session, data.image_generation_config_id, search_space ) - session.add(db_image_gen) - await session.flush() - try: - await _execute_image_generation(session, db_image_gen, search_space) - except Exception as e: - logger.exception("Image generation call failed") - db_image_gen.error_message = str(e) + # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError + # propagates to the outer ``except QuotaInsufficientError`` handler + # below as HTTP 402 — it is intentionally NOT swallowed into + # ``error_message`` because that would (1) imply a successful row + # exists when none does, and (2) return HTTP 200 to a client + # whose request was actively *denied* (issue K). + async with billable_call( + user_id=search_space.user_id, + search_space_id=data.search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=reserve_micros, + usage_type="image_generation", + call_details={"model": base_model, "prompt": data.prompt[:100]}, + ): + db_image_gen = ImageGeneration( + prompt=data.prompt, + model=data.model, + n=data.n, + quality=data.quality, + size=data.size, + style=data.style, + response_format=data.response_format, + image_generation_config_id=data.image_generation_config_id, + search_space_id=data.search_space_id, + created_by_id=user.id, + ) + session.add(db_image_gen) + await session.flush() - await session.commit() - await session.refresh(db_image_gen) - return db_image_gen + try: + await _execute_image_generation(session, db_image_gen, search_space) + except Exception as e: + logger.exception("Image generation call failed") + db_image_gen.error_message = str(e) + + await session.commit() + await session.refresh(db_image_gen) + return db_image_gen except HTTPException: raise + except QuotaInsufficientError as exc: + # The user's premium credit pool is empty. No DB row is created + # because ``billable_call`` denies before yielding (issue K). + await session.rollback() + raise HTTPException( + status_code=402, + detail={ + "error_code": "premium_quota_exhausted", + "usage_type": exc.usage_type, + "used_micros": exc.used_micros, + "limit_micros": exc.limit_micros, + "remaining_micros": exc.remaining_micros, + "message": ( + "Out of premium credits for image generation. " + "Purchase additional credits or switch to a free model." 
+ ), + }, + ) from exc except SQLAlchemyError: await session.rollback() raise HTTPException( diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index c95553fce..ad96654f5 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -15,9 +15,10 @@ import json import logging from datetime import UTC, datetime -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi.responses import StreamingResponse -from sqlalchemy import func, or_ +from sqlalchemy import func, or_, text as sa_text +from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -29,6 +30,12 @@ from app.agents.new_chat.filesystem_selection import ( FilesystemSelection, LocalFilesystemMount, ) +from app.agents.new_chat.middleware.busy_mutex import ( + get_cancel_state, + is_cancel_requested, + manager, + request_cancel, +) from app.config import config from app.db import ( ChatComment, @@ -38,12 +45,14 @@ from app.db import ( NewChatThread, Permission, SearchSpace, + TokenUsage, User, get_async_session, shielded_async_session, ) from app.schemas.new_chat import ( AgentToolInfo, + CancelActiveTurnResponse, LocalFilesystemMountPayload, NewChatMessageRead, NewChatRequest, @@ -60,10 +69,11 @@ from app.schemas.new_chat import ( ThreadListItem, ThreadListResponse, TokenUsageSummary, + TurnStatusResponse, ) -from app.services.token_tracking_service import record_token_usage from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat from app.users import current_active_user +from app.utils.perf import get_perf_logger from app.utils.rbac import check_permission from app.utils.user_message_multimodal import ( split_langchain_human_content, @@ -71,7 +81,11 @@ from app.utils.user_message_multimodal import ( ) _logger = logging.getLogger(__name__) +_perf_log = get_perf_logger() _background_tasks: set[asyncio.Task] = set() +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 router = APIRouter() @@ -137,6 +151,72 @@ def _resolve_filesystem_selection( ) +def _compute_turn_cancelling_retry_delay(attempt: int) -> int: + """Bounded exponential delay for TURN_CANCELLING retry hints.""" + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) + + +def _build_turn_status_payload(thread_id: int) -> dict[str, object]: + lock = manager.lock_for(str(thread_id)) + if not lock.locked(): + return {"status": "idle"} + + if is_cancel_requested(str(thread_id)): + cancel_state = get_cancel_state(str(thread_id)) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms + return { + "status": "cancelling", + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + } + + return {"status": "busy"} + + +def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None: + response.headers["retry-after-ms"] = str(retry_after_ms) + response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000)) + + +def 
_raise_if_thread_busy_for_start(thread_id: int) -> None: + status_payload = _build_turn_status_payload(thread_id) + status = status_payload["status"] + if status == "idle": + return + if status == "cancelling": + retry_after_ms = int(status_payload.get("retry_after_ms") or 0) + detail = { + "errorCode": "TURN_CANCELLING", + "message": "A previous response is still stopping. Please try again in a moment.", + "retry_after_ms": retry_after_ms if retry_after_ms > 0 else None, + "retry_after_at": status_payload.get("retry_after_at"), + } + headers = ( + { + "retry-after-ms": str(retry_after_ms), + "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)), + } + if retry_after_ms > 0 + else None + ) + raise HTTPException(status_code=409, detail=detail, headers=headers) + + raise HTTPException( + status_code=409, + detail={ + "errorCode": "THREAD_BUSY", + "message": "Another response is still finishing for this thread. Please try again in a moment.", + }, + ) + + def _find_pre_turn_checkpoint_id( checkpoint_tuples: list, *, @@ -1210,6 +1290,24 @@ async def append_message( user: User = Depends(current_active_user), ): """ + .. deprecated:: 2026-05 + Replaced by the **SSE-based message ID handshake**. The streaming + generator (`stream_new_chat` / `stream_resume_chat`) now persists + both the user and assistant rows server-side via + ``persist_user_turn`` / ``persist_assistant_shell`` and emits + ``data-user-message-id`` / ``data-assistant-message-id`` SSE events + so the frontend can rename its optimistic IDs in real time. The + new FE bundle no longer calls this route. + + This handler is retained as a **silent no-op for legacy / cached + FE bundles**: the underlying ``INSERT ... ON CONFLICT DO NOTHING`` + pattern means a stale bundle hitting this route after the SSE + handshake already wrote the row simply returns the existing row + (200 OK) without raising or duplicating data. After a 2-week soak + (target: ``[persist_user_turn] outcome=race_recovered`` rate ~0) + this entire route — and the FE ``appendMessage`` function — is + earmarked for removal. + Append a message to a thread. This is used by ThreadHistoryAdapter.append() to persist messages. @@ -1220,6 +1318,22 @@ async def append_message( Requires CHATS_UPDATE permission. """ try: + # Capture ``user.id`` as a primitive UUID up front. The + # ``current_active_user`` dependency hands us a ``User`` ORM + # row bound to ``session``; if the outer ``except + # IntegrityError`` block below ever fires (an unexpected + # constraint like a foreign key violation — the common + # ``(thread_id, turn_id, role)`` race is now handled silently + # by ``ON CONFLICT DO NOTHING`` so it never raises) it calls + # ``session.rollback()``, which expires every attached ORM + # row including this user. Any later ``user.id`` access would + # then trigger a lazy PK reload — which on async SQLAlchemy + # fails with ``MissingGreenlet`` because the reload happens + # outside the awaitable greenlet boundary. Reading ``id`` + # once here pins the value as a plain UUID so all downstream + # uses (TokenUsage insert, response build) are immune. 
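+        # A minimal sketch of the failure mode this guards against
+        # (hypothetical handler shape, not code from this repo):
+        #
+        #   user = await session.get(User, uid)      # attached ORM row
+        #   try:
+        #       await session.flush()
+        #   except IntegrityError:
+        #       await session.rollback()             # expires user.*
+        #   return {"author": user.id}               # lazy PK refresh ->
+        #                                            # MissingGreenlet
+        #
+        # Hence the early capture as a plain value: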
+ user_uuid = user.id + # Parse raw body - extract only role and content, ignoring extra fields raw_body = await request.json() role = raw_body.get("role") @@ -1274,37 +1388,166 @@ async def append_message( else None ) - db_message = NewChatMessage( - thread_id=thread_id, - role=message_role, - content=content, - author_id=user.id, - turn_id=turn_id_value, - ) - session.add(db_message) - - # Update thread's updated_at timestamp + # Update thread's updated_at timestamp (always — both insert + # and recovery paths represent thread activity). thread.updated_at = datetime.now(UTC) - # flush assigns the PK/defaults without a round-trip SELECT - await session.flush() + # Insert the new message via ``INSERT ... ON CONFLICT DO NOTHING`` + # keyed on the ``(thread_id, turn_id, role)`` partial unique + # index from migration 141 (``WHERE turn_id IS NOT NULL``). + # + # Why ON CONFLICT instead of ``session.add() + flush() + except + # IntegrityError``: + # 1. The conflict between this legacy FE ``appendMessage`` + # round-trip and the server-side + # ``finalize_assistant_turn`` writer is a NORMAL, + # *expected* race — every assistant turn fires it. Using + # catch-and-recover means asyncpg raises + # ``UniqueViolationError`` -> SQLAlchemy wraps it as + # ``IntegrityError`` -> our handler catches and recovers. + # Functionally fine, but every ``raise`` event lights up + # VS Code's debugger (debugpy's ``justMyCode=false`` mode + # loses track of the catch frame across SQLAlchemy's + # async greenlet boundary, so even ``Raised Exceptions`` + # being unchecked doesn't reliably suppress the pause). + # ON CONFLICT pushes the conflict resolution into Postgres + # where no Python exception is constructed at all. + # 2. No ``session.rollback()`` -> no expiring of attached + # ORM rows -> no risk of ``MissingGreenlet`` from + # lazy-loading expired user/thread state later in the + # handler. + # 3. Cleaner production logs (no SQLAlchemy ``IntegrityError`` + # tracebacks emitted by uvicorn's logger between the + # ``raise`` and our ``except``). + # + # When ``turn_id_value`` is ``None`` the partial index doesn't + # apply and the INSERT proceeds normally. Other constraint + # violations (FK, NOT NULL, etc.) still raise ``IntegrityError`` + # and are caught by the outer ``except IntegrityError`` block + # to preserve the legacy 400 behavior. + # + # Note on ``content``: when we recover the existing row, we + # intentionally discard the FE's ``content`` payload from + # ``raw_body`` and return the row's existing ``content``. The + # streaming task is now the *authoritative writer* for + # assistant ``ContentPart[]`` shape (mid-stream + # ``AssistantContentBuilder`` -> ``finalize_assistant_turn``) + # so the FE's later ``appendMessage`` is just a stale snapshot + # of the same data — keeping the server-built rich content + # (with full tool-call args / argsText / langchainToolCallId) + # is correct, not lossy. + insert_stmt = ( + pg_insert(NewChatMessage) + .values( + thread_id=thread_id, + role=message_role, + content=content, + author_id=user_uuid, + turn_id=turn_id_value, + ) + .on_conflict_do_nothing( + index_elements=["thread_id", "turn_id", "role"], + index_where=sa_text("turn_id IS NOT NULL"), + ) + .returning(NewChatMessage.id) + ) + inserted_id = (await session.execute(insert_stmt)).scalar() - # Persist token usage if provided (for assistant messages) + if inserted_id is None: + # Conflict on partial unique index — server-side stream + # already wrote this row. Look it up and reuse it. 
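+            # The expected interleaving, as an illustrative timeline:
+            #
+            #   t0  stream task : INSERT (thread, turn, "assistant") -> row A
+            #   t1  legacy FE   : POST appendMessage -> ON CONFLICT -> None
+            #   t2  this branch : SELECT row A, return it with HTTP 200
+            #
+            # First, rule out the one case the partial index cannot produce: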
+ if turn_id_value is None: + # Defensive: ON CONFLICT only fires for ``turn_id IS + # NOT NULL`` rows, so this branch should be + # unreachable. Preserve the legacy 400 just in case + # Postgres ever surprises us. + raise HTTPException( + status_code=400, + detail="Database constraint violation. Please check your input data.", + ) from None + lookup = await session.execute( + select(NewChatMessage).filter( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id_value, + NewChatMessage.role == message_role, + ) + ) + existing_message = lookup.scalars().first() + if existing_message is None: + # Conflict reported but the row vanished between + # INSERT and SELECT — extremely unlikely (would + # require a concurrent DELETE within the same + # transaction visibility), but preserve safe + # behavior. + raise HTTPException( + status_code=400, + detail="Database constraint violation. Please check your input data.", + ) from None + db_message = existing_message + # Perf signal: counts how often the legacy FE round-trip + # races the server-side ``finalize_assistant_turn``. A + # rising rate after the rework is OK (it's exactly the + # ghost-thread fix's recovery path firing); a sudden drop + # to zero would mean the FE isn't posting appendMessage + # at all (different bug). + _perf_log.info( + "[append_message] outcome=recovered_via_unique_index " + "thread_id=%s turn_id=%s role=%s message_id=%s", + thread_id, + turn_id_value, + message_role.value, + db_message.id, + ) + else: + # INSERT succeeded — load the full ORM row so the + # response can include server-side-defaulted columns + # (``created_at``, etc.) and the relationship surface + # stays consistent with the recovery path. + inserted_row = await session.get(NewChatMessage, inserted_id) + if inserted_row is None: + # Should be impossible: we just inserted it in this + # same transaction. Fail loud if it happens. + raise HTTPException( + status_code=500, + detail="Inserted message could not be loaded.", + ) from None + db_message = inserted_row + + # Persist token usage if provided (for assistant messages). + # ``cost_micros`` is the provider USD cost reported by LiteLLM, + # forwarded by the FE through the appendMessage round-trip so + # the historical TokenUsage row matches the credit debit applied + # at finalize time. + # + # De-dup: ``finalize_assistant_turn`` may also race to write a + # token_usage row for this same ``message_id`` (cross-session, + # cross-shielded). Use ``INSERT ... ON CONFLICT DO NOTHING`` keyed + # on the ``uq_token_usage_message_id`` partial unique index + # (migration 142). The loser silently drops its insert; exactly + # one row results regardless of which writer commits first. 
token_usage_data = raw_body.get("token_usage") if token_usage_data and message_role == NewChatMessageRole.ASSISTANT: - await record_token_usage( - session, - usage_type="chat", - search_space_id=thread.search_space_id, - user_id=user.id, - prompt_tokens=token_usage_data.get("prompt_tokens", 0), - completion_tokens=token_usage_data.get("completion_tokens", 0), - total_tokens=token_usage_data.get("total_tokens", 0), - model_breakdown=token_usage_data.get("usage"), - call_details=token_usage_data.get("call_details"), - thread_id=thread_id, - message_id=db_message.id, + insert_stmt = ( + pg_insert(TokenUsage) + .values( + usage_type="chat", + prompt_tokens=token_usage_data.get("prompt_tokens", 0), + completion_tokens=token_usage_data.get("completion_tokens", 0), + total_tokens=token_usage_data.get("total_tokens", 0), + cost_micros=token_usage_data.get("cost_micros", 0), + model_breakdown=token_usage_data.get("usage"), + call_details=token_usage_data.get("call_details"), + thread_id=thread_id, + message_id=db_message.id, + search_space_id=thread.search_space_id, + user_id=user_uuid, + ) + .on_conflict_do_nothing( + index_elements=["message_id"], + index_where=sa_text("message_id IS NOT NULL"), + ) ) + await session.execute(insert_stmt) await session.commit() @@ -1324,6 +1567,9 @@ async def append_message( except HTTPException: raise except IntegrityError: + # Any IntegrityError that escaped the inline handler above + # comes from a *different* constraint (foreign key, etc.) — + # preserve the legacy 400 path. await session.rollback() raise HTTPException( status_code=400, @@ -1476,6 +1722,7 @@ async def handle_new_chat( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(request.chat_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, @@ -1516,6 +1763,12 @@ async def handle_new_chat( else None ) + mentioned_documents_payload = ( + [doc.model_dump() for doc in request.mentioned_documents] + if request.mentioned_documents + else None + ) + return StreamingResponse( stream_new_chat( user_query=request.user_query, @@ -1525,6 +1778,7 @@ async def handle_new_chat( llm_config_id=llm_config_id, mentioned_document_ids=request.mentioned_document_ids, mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, + mentioned_documents=mentioned_documents_payload, needs_history_bootstrap=thread.needs_history_bootstrap, thread_visibility=thread.visibility, current_user_display_name=user.display_name or "A team member", @@ -1550,6 +1804,93 @@ async def handle_new_chat( ) from None +@router.post( + "/threads/{thread_id}/cancel-active-turn", + response_model=CancelActiveTurnResponse, +) +async def cancel_active_turn( + thread_id: int, + response: Response, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Signal cancellation for the currently running turn on ``thread_id``.""" + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_UPDATE.value, + "You don't have permission to update chats in this search space", + ) + await check_thread_access(session, thread, user) + + status_payload = _build_turn_status_payload(thread_id) + if 
status_payload["status"] == "idle": + return CancelActiveTurnResponse( + status="idle", + error_code="NO_ACTIVE_TURN", + ) + + request_cancel(str(thread_id)) + response.status_code = 202 + updated_payload = _build_turn_status_payload(thread_id) + retry_after_ms = int(updated_payload.get("retry_after_ms") or 0) + retry_after_at = ( + int(updated_payload["retry_after_at"]) + if "retry_after_at" in updated_payload + else None + ) + if retry_after_ms > 0: + _set_retry_after_headers(response, retry_after_ms) + return CancelActiveTurnResponse( + status="cancelling", + error_code="TURN_CANCELLING", + retry_after_ms=retry_after_ms if retry_after_ms > 0 else None, + retry_after_at=retry_after_at, + ) + + +@router.get( + "/threads/{thread_id}/turn-status", + response_model=TurnStatusResponse, +) +async def get_turn_status( + thread_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to view chats in this search space", + ) + await check_thread_access(session, thread, user) + + status_payload = _build_turn_status_payload(thread_id) + return TurnStatusResponse( + status=status_payload["status"], # type: ignore[arg-type] + active_turn_id=None, + retry_after_ms=status_payload.get("retry_after_ms"), # type: ignore[arg-type] + retry_after_at=status_payload.get("retry_after_at"), # type: ignore[arg-type] + ) + + # ============================================================================= # Chat Regeneration Endpoint (Edit/Reload) # ============================================================================= @@ -1605,6 +1946,7 @@ async def regenerate_response( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(thread_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, @@ -1907,6 +2249,11 @@ async def regenerate_response( "data": revert_results, } yield f"data: {json.dumps(envelope, default=str)}\n\n".encode() + mentioned_documents_payload = ( + [doc.model_dump() for doc in request.mentioned_documents] + if request.mentioned_documents + else None + ) try: async for chunk in stream_new_chat( user_query=str(user_query_to_use), @@ -1916,6 +2263,7 @@ async def regenerate_response( llm_config_id=llm_config_id, mentioned_document_ids=request.mentioned_document_ids, mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, + mentioned_documents=mentioned_documents_payload, checkpoint_id=target_checkpoint_id, needs_history_bootstrap=thread.needs_history_bootstrap, thread_visibility=thread.visibility, @@ -1924,6 +2272,7 @@ async def regenerate_response( filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), user_image_data_urls=regenerate_image_urls or None, + flow="regenerate", ): yield chunk streaming_completed = True @@ -2011,6 +2360,7 @@ async def resume_chat( ) await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(thread_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, diff --git 
a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py index 20779a309..e090a1a7c 100644 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -29,6 +29,7 @@ from app.schemas import ( NewLLMConfigUpdate, ) from app.services.llm_service import validate_llm_config +from app.services.provider_capabilities import derive_supports_image_input from app.users import current_active_user from app.utils.rbac import check_permission @@ -36,6 +37,39 @@ router = APIRouter() logger = logging.getLogger(__name__) +def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: + """Augment a BYOK chat config row with the derived ``supports_image_input``. + + There is no DB column for ``supports_image_input`` — the value is + resolved at the API boundary from LiteLLM's authoritative model map + (default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps + the response shape consistent across list / detail / create / update + endpoints without having to remember to set the field at every call + site. + """ + provider_value = ( + config.provider.value + if hasattr(config.provider, "value") + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") if isinstance(litellm_params, dict) else None + ) + supports_image_input = derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ) + # ``model_validate`` runs the Pydantic conversion using the ORM + # attribute access path enabled by ``ConfigDict(from_attributes=True)``, + # then we layer the derived field on. ``model_copy(update=...)`` keeps + # the surface immutable from the caller's perspective. + base_read = NewLLMConfigRead.model_validate(config) + return base_read.model_copy(update={"supports_image_input": supports_image_input}) + + # ============================================================================= # Global Configs Routes # ============================================================================= @@ -84,11 +118,41 @@ async def get_global_new_llm_configs( "seo_title": None, "seo_description": None, "quota_reserve_tokens": None, + # Auto routes across the configured pool, which usually + # includes at least one vision-capable deployment, so + # treat Auto as image-capable. The router itself will + # still pick a vision-capable deployment for messages + # carrying image_url blocks (LiteLLM Router falls back + # on ``404`` per its ``allowed_fails`` policy). + "supports_image_input": True, } ) # Add individual global configs for cfg in global_configs: + # Capability resolution: explicit value (YAML override or OR + # `_supports_image_input(model)` payload baked in by the + # OpenRouter integration service) wins. Fall back to the + # LiteLLM-driven helper which default-allows on unknown so + # we don't hide vision-capable models that happen to lack a + # YAML annotation. The streaming task safety net is the + # only place a False ever blocks. 
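+            # Assumed helper behavior (illustrative; the signature comes
+            # from the import above, the outcomes from its documented
+            # default-allow stance):
+            #
+            #   derive_supports_image_input(provider="OPENAI",
+            #       model_name="gpt-4o",
+            #       base_model=None, custom_provider=None)     -> True
+            #   derive_supports_image_input(provider="CUSTOM",
+            #       model_name="in-house-llm",
+            #       base_model=None, custom_provider="ollama") -> True
+            #       (unknown to LiteLLM's map: default-allow)
+            #
+            # Resolution order for each config in this loop: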
+ if "supports_image_input" in cfg: + supports_image_input = bool(cfg.get("supports_image_input")) + else: + cfg_litellm_params = cfg.get("litellm_params") or {} + cfg_base_model = ( + cfg_litellm_params.get("base_model") + if isinstance(cfg_litellm_params, dict) + else None + ) + supports_image_input = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=cfg_base_model, + custom_provider=cfg.get("custom_provider"), + ) + safe_config = { "id": cfg.get("id"), "name": cfg.get("name"), @@ -113,6 +177,7 @@ async def get_global_new_llm_configs( "seo_title": cfg.get("seo_title"), "seo_description": cfg.get("seo_description"), "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "supports_image_input": supports_image_input, } safe_configs.append(safe_config) @@ -171,7 +236,7 @@ async def create_new_llm_config( await session.commit() await session.refresh(db_config) - return db_config + return _serialize_byok_config(db_config) except HTTPException: raise @@ -213,7 +278,7 @@ async def list_new_llm_configs( .limit(limit) ) - return result.scalars().all() + return [_serialize_byok_config(cfg) for cfg in result.scalars().all()] except HTTPException: raise @@ -268,7 +333,7 @@ async def get_new_llm_config( "You don't have permission to view LLM configurations in this search space", ) - return config + return _serialize_byok_config(config) except HTTPException: raise @@ -360,7 +425,7 @@ async def update_new_llm_config( await session.commit() await session.refresh(config) - return config + return _serialize_byok_config(config) except HTTPException: raise diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 828137518..5ecfb1814 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -3,7 +3,7 @@ import logging from fastapi import APIRouter, Depends, HTTPException from langchain_core.messages import HumanMessage from pydantic import BaseModel as PydanticBaseModel -from sqlalchemy import func +from sqlalchemy import func, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem from app.config import config from app.db import ( ImageGenerationConfig, + NewChatThread, NewLLMConfig, Permission, SearchSpace, @@ -593,6 +594,7 @@ async def _get_image_gen_config_by_id( "model_name": "auto", "is_global": True, "is_auto_mode": True, + "billing_tier": "free", } if config_id < 0: @@ -609,6 +611,7 @@ async def _get_image_gen_config_by_id( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), } return None @@ -651,6 +654,7 @@ async def _get_vision_llm_config_by_id( "model_name": "auto", "is_global": True, "is_auto_mode": True, + "billing_tier": "free", } if config_id < 0: @@ -667,6 +671,7 @@ async def _get_vision_llm_config_by_id( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), } return None @@ -790,9 +795,27 @@ async def update_llm_preferences( # Update preferences update_data = preferences.model_dump(exclude_unset=True) + previous_agent_llm_id = search_space.agent_llm_id for key, value in update_data.items(): setattr(search_space, key, value) + 
agent_llm_changed = ( + "agent_llm_id" in update_data + and update_data["agent_llm_id"] != previous_agent_llm_id + ) + if agent_llm_changed: + await session.execute( + update(NewChatThread) + .where(NewChatThread.search_space_id == search_space_id) + .values(pinned_llm_config_id=None) + ) + logger.info( + "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)", + search_space_id, + previous_agent_llm_id, + update_data["agent_llm_id"], + ) + await session.commit() await session.refresh(search_space) diff --git a/surfsense_backend/app/routes/stripe_routes.py b/surfsense_backend/app/routes/stripe_routes.py index cfdd4b52a..aed74ec8d 100644 --- a/surfsense_backend/app/routes/stripe_routes.py +++ b/surfsense_backend/app/routes/stripe_routes.py @@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase( metadata = _get_metadata(checkout_session) user_id = metadata.get("user_id") quantity = int(metadata.get("quantity", "0")) - tokens_per_unit = int(metadata.get("tokens_per_unit", "0")) + # Read the new metadata key first, fall back to the legacy one so + # in-flight checkout sessions created before the cost-credits + # release still fulfil correctly (the unit is numerically the + # same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens). + credit_micros_per_unit = int( + metadata.get("credit_micros_per_unit") + or metadata.get("tokens_per_unit", "0") + ) - if not user_id or quantity <= 0 or tokens_per_unit <= 0: + if not user_id or quantity <= 0 or credit_micros_per_unit <= 0: logger.error( "Skipping token fulfillment for session %s: incomplete metadata %s", checkout_session_id, @@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase( getattr(checkout_session, "payment_intent", None) ), quantity=quantity, - tokens_granted=quantity * tokens_per_unit, + credit_micros_granted=quantity * credit_micros_per_unit, amount_total=getattr(checkout_session, "amount_total", None), currency=getattr(checkout_session, "currency", None), status=PremiumTokenPurchaseStatus.PENDING, @@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase( purchase.stripe_payment_intent_id = _normalize_optional_string( getattr(checkout_session, "payment_intent", None) ) - user.premium_tokens_limit = ( - max(user.premium_tokens_used, user.premium_tokens_limit) - + purchase.tokens_granted + # Top up the user's credit balance by the granted micro-USD amount. + # ``max(used, limit)`` clamps the case where the legacy code wrote a + # used value above the limit (e.g. underbilling rounding) so adding + # ``credit_micros_granted`` always lifts the limit by the full pack + # size rather than disappearing into past overuse. + user.premium_credit_micros_limit = ( + max(user.premium_credit_micros_used, user.premium_credit_micros_limit) + + purchase.credit_micros_granted ) await db_session.commit() @@ -532,12 +544,18 @@ async def create_token_checkout_session( user: User = Depends(current_active_user), db_session: AsyncSession = Depends(get_async_session), ): - """Create a Stripe Checkout Session for buying premium token packs.""" + """Create a Stripe Checkout Session for buying premium credit packs. + + Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of + credit (default 1_000_000 = $1.00). The user's balance is debited + at the actual provider cost reported by LiteLLM at finalize time, + so $1 of credit always buys $1 worth of provider usage at cost. 
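+
+    A worked example with the defaults (illustrative numbers): a
+    checkout with ``quantity=3`` grants ``3 * 1_000_000 = 3_000_000``
+    micro-USD ($3.00) of credit; a later premium call whose LiteLLM
+    ``response_cost`` comes back as $0.0123 then debits about
+    ``0.0123 * 1_000_000 = 12_300`` micro-USD of that balance.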
+ """ _ensure_token_buying_enabled() stripe_client = get_stripe_client() price_id = _get_required_token_price_id() success_url, cancel_url = _get_token_checkout_urls(body.search_space_id) - tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT + credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT try: checkout_session = stripe_client.v1.checkout.sessions.create( @@ -556,8 +574,8 @@ async def create_token_checkout_session( "metadata": { "user_id": str(user.id), "quantity": str(body.quantity), - "tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT), - "purchase_type": "premium_tokens", + "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT), + "purchase_type": "premium_credit", }, } ) @@ -583,7 +601,7 @@ async def create_token_checkout_session( getattr(checkout_session, "payment_intent", None) ), quantity=body.quantity, - tokens_granted=tokens_granted, + credit_micros_granted=credit_micros_granted, amount_total=getattr(checkout_session, "amount_total", None), currency=getattr(checkout_session, "currency", None), status=PremiumTokenPurchaseStatus.PENDING, @@ -598,14 +616,19 @@ async def create_token_checkout_session( async def get_token_status( user: User = Depends(current_active_user), ): - """Return token-buying availability and current premium quota for frontend.""" - used = user.premium_tokens_used - limit = user.premium_tokens_limit + """Return token-buying availability and current premium credit quota for frontend. + + Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M + when displaying. The route name is preserved for back-compat with + pinned client deployments. + """ + used = user.premium_credit_micros_used + limit = user.premium_credit_micros_limit return TokenStripeStatusResponse( token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED, - premium_tokens_used=used, - premium_tokens_limit=limit, - premium_tokens_remaining=max(0, limit - used), + premium_credit_micros_used=used, + premium_credit_micros_limit=limit, + premium_credit_micros_remaining=max(0, limit - used), ) diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index 315c7c9fe..e4f08f604 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -82,10 +82,15 @@ async def get_global_vision_llm_configs( "litellm_params": {}, "is_global": True, "is_auto_mode": True, + # Auto mode treated as free until per-deployment billing-tier + # surfacing lands; see ``get_vision_llm`` for parity. + "billing_tier": "free", + "is_premium": False, } ) for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() safe_configs.append( { "id": cfg.get("id"), @@ -98,6 +103,14 @@ async def get_global_vision_llm_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. 
+ "is_premium": billing_tier == "premium", + "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "input_cost_per_token": cfg.get("input_cost_per_token"), + "output_cost_per_token": cfg.get("output_cost_per_token"), } ) diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index 69f534e20..4262b2b3f 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel): Schema for reading global image generation configs from YAML. Global configs have negative IDs. API key is hidden. ID 0 is reserved for Auto mode (LiteLLM Router load balancing). + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. ``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. """ id: int = Field( @@ -231,3 +237,24 @@ class GlobalImageGenConfigRead(BaseModel): litellm_params: dict[str, Any] | None = None is_global: bool = True is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", + ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." + ), + ) + quota_reserve_micros: int | None = Field( + default=None, + description=( + "Optional override for the reservation amount (in micro-USD) used when " + "this image generation is premium. Falls back to " + "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted." + ), + ) diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index cfb4b8b37..95d183433 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel): prompt_tokens: int = 0 completion_tokens: int = 0 total_tokens: int = 0 + cost_micros: int = 0 model_breakdown: dict | None = None model_config = ConfigDict(from_attributes=True) @@ -199,6 +200,21 @@ class NewChatUserImagePart(BaseModel): return to_data_url(self.media_type, self.data) +class MentionedDocumentInfo(BaseModel): + """Display metadata for a single ``@``-mentioned document. + + The full triple ``{id, title, document_type}`` is forwarded by the + frontend mention chip so the server can embed it in the persisted + user message ``ContentPart[]`` (single ``mentioned-documents`` part). + The history loader then renders the chips on reload without an extra + fetch — mirrors the pre-refactor frontend ``persistUserTurn`` shape. 
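+
+    Example element as the frontend would send it (illustrative values,
+    not a fixture from this repo):
+
+        {"id": 123, "title": "Q3 roadmap.pdf", "document_type": "FILE"}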
+ """ + + id: int + title: str = Field(..., min_length=1, max_length=500) + document_type: str = Field(..., min_length=1, max_length=100) + + class NewChatRequest(BaseModel): """Request schema for the deep agent chat endpoint.""" @@ -212,6 +228,17 @@ class NewChatRequest(BaseModel): mentioned_surfsense_doc_ids: list[int] | None = ( None # Optional SurfSense documentation IDs mentioned with @ in the chat ) + mentioned_documents: list[MentionedDocumentInfo] | None = Field( + default=None, + description=( + "Display metadata (id, title, document_type) for every " + "@-mentioned document. Persisted as a ``mentioned-documents`` " + "ContentPart on the user message so reload renders chips " + "without an extra fetch. Optional and additive — when None " + "the user message is persisted without a mentioned-documents " + "part." + ), + ) disabled_tools: list[str] | None = ( None # Optional list of tool names the user has disabled from the UI ) @@ -263,6 +290,16 @@ class RegenerateRequest(BaseModel): ) mentioned_document_ids: list[int] | None = None mentioned_surfsense_doc_ids: list[int] | None = None + mentioned_documents: list[MentionedDocumentInfo] | None = Field( + default=None, + description=( + "Display metadata (id, title, document_type) for every " + "@-mentioned document on the edited user turn. Only used " + "when ``user_query`` is non-None (edit). Persisted as a " + "``mentioned-documents`` ContentPart on the new user " + "message. None means no chip metadata." + ), + ) disabled_tools: list[str] | None = None filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" client_platform: Literal["web", "desktop"] = "web" @@ -336,6 +373,34 @@ class ResumeRequest(BaseModel): filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" client_platform: Literal["web", "desktop"] = "web" local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None + mentioned_documents: list[MentionedDocumentInfo] | None = Field( + default=None, + description=( + "Display metadata forwarded for symmetry with /new_chat and " + "/regenerate. Resume reuses the original interrupted user " + "turn so the server does not write a new user message. " + "Currently unused but accepted to keep request bodies " + "uniform across the three streaming entrypoints." + ), + ) + + +class CancelActiveTurnResponse(BaseModel): + """Response for canceling an active turn on a chat thread.""" + + status: Literal["cancelling", "idle"] + error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"] + retry_after_ms: int | None = None + retry_after_at: int | None = None + + +class TurnStatusResponse(BaseModel): + """Current turn execution status for a thread.""" + + status: Literal["idle", "busy", "cancelling"] + active_turn_id: str | None = None + retry_after_ms: int | None = None + retry_after_at: int | None = None # ============================================================================= diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py index 9cc1fce58..e64478d38 100644 --- a/surfsense_backend/app/schemas/new_llm_config.py +++ b/surfsense_backend/app/schemas/new_llm_config.py @@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase): created_at: datetime search_space_id: int user_id: uuid.UUID + # Capability flag derived at the API boundary (no DB column). Default + # True matches the conservative-allow stance — a BYOK row that the + # route forgot to augment is not pre-judged. 
The streaming-task + # safety net is the only place a False actually blocks a request. + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map " + "(``litellm.supports_vision``) — there is no DB column. " + "Default True is the conservative-allow stance for unknown / " + "unmapped models." + ), + ) model_config = ConfigDict(from_attributes=True) @@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel): created_at: datetime search_space_id: int user_id: uuid.UUID + # Capability flag derived at the API boundary (see NewLLMConfigRead). + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map. " + "Default True is the conservative-allow stance." + ), + ) model_config = ConfigDict(from_attributes=True) @@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel): seo_title: str | None = None seo_description: str | None = None quota_reserve_tokens: int | None = None + supports_image_input: bool = Field( + default=True, + description=( + "Whether the model accepts image inputs (multimodal vision). " + "Derived server-side: OpenRouter dynamic configs use " + "``architecture.input_modalities``; YAML / BYOK use LiteLLM's " + "authoritative model map (``litellm.supports_vision``). The " + "new-chat selector hints with a 'No image' badge when this is " + "False and there are pending image attachments. The streaming " + "task fails fast only when LiteLLM *explicitly* marks a model " + "as text-only — unknown / unmapped models default-allow." + ), + ) # ============================================================================= diff --git a/surfsense_backend/app/schemas/stripe.py b/surfsense_backend/app/schemas/stripe.py index 3edd3e9e4..57265ec8e 100644 --- a/surfsense_backend/app/schemas/stripe.py +++ b/surfsense_backend/app/schemas/stripe.py @@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel): class TokenPurchaseRead(BaseModel): - """Serialized premium token purchase record.""" + """Serialized premium credit purchase record. + + ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The + schema name kept ``Token`` for API back-compat with pinned clients. + """ id: uuid.UUID stripe_checkout_session_id: str stripe_payment_intent_id: str | None = None quantity: int - tokens_granted: int + credit_micros_granted: int amount_total: int | None = None currency: str | None = None status: str @@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel): class TokenPurchaseHistoryResponse(BaseModel): - """Response containing the user's premium token purchases.""" + """Response containing the user's premium credit purchases.""" purchases: list[TokenPurchaseRead] class TokenStripeStatusResponse(BaseModel): - """Response describing token-buying availability and current quota.""" + """Response describing premium-credit-buying availability and balance. + + All ``premium_credit_micros_*`` fields are in micro-USD; the FE + divides by 1_000_000 to display USD. 
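+
+    Worked example: ``used=2_500_000`` against ``limit=5_000_000`` is
+    rendered as $2.50 spent of $5.00, and
+    ``premium_credit_micros_remaining`` comes back as 2_500_000 ($2.50).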
+ """ token_buying_enabled: bool - premium_tokens_used: int = 0 - premium_tokens_limit: int = 0 - premium_tokens_remaining: int = 0 + premium_credit_micros_used: int = 0 + premium_credit_micros_limit: int = 0 + premium_credit_micros_remaining: int = 0 diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py index ab2e609dc..d0eeaf5c6 100644 --- a/surfsense_backend/app/schemas/vision_llm.py +++ b/surfsense_backend/app/schemas/vision_llm.py @@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel): class GlobalVisionLLMConfigRead(BaseModel): + """Schema for reading global vision LLM configs from YAML. + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. ``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. + """ + id: int = Field(...) name: str description: str | None = None @@ -73,3 +82,35 @@ class GlobalVisionLLMConfigRead(BaseModel): litellm_params: dict[str, Any] | None = None is_global: bool = True is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", + ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." + ), + ) + quota_reserve_tokens: int | None = Field( + default=None, + description=( + "Optional override for the per-call reservation in *tokens* — " + "converted to micro-USD via the model's input/output prices at " + "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS." + ), + ) + input_cost_per_token: float | None = Field( + default=None, + description=( + "Optional input price in USD/token. Used by pricing_registration to " + "register custom Azure / OpenRouter aliases with LiteLLM at startup." + ), + ) + output_cost_per_token: float | None = Field( + default=None, + description="Optional output price in USD/token. Pair with input_cost_per_token.", + ) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py new file mode 100644 index 000000000..9bbca8669 --- /dev/null +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -0,0 +1,479 @@ +"""Resolve and persist Auto (Fastest) model pins per chat thread. + +Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we +resolve that virtual mode to one concrete global LLM config exactly once and +persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so +subsequent turns are stable. + +Single-writer invariant: this module is the only writer of +``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in +``search_spaces_routes`` when a search space's ``agent_llm_id`` changes). +Therefore a non-NULL value unambiguously means "this thread has an +Auto-resolved pin"; no separate source/policy column is needed. 
+""" + +from __future__ import annotations + +import hashlib +import logging +import threading +import time +from dataclasses import dataclass +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import NewChatThread +from app.services.quality_score import _QUALITY_TOP_K +from app.services.token_quota_service import TokenQuotaService + +logger = logging.getLogger(__name__) + +AUTO_FASTEST_ID = 0 +AUTO_FASTEST_MODE = "auto_fastest" +_RUNTIME_COOLDOWN_SECONDS = 600 +_HEALTHY_TTL_SECONDS = 45 + +# In-memory runtime cooldown map for configs that recently hard-failed at +# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps +# the same unhealthy config from being reselected immediately during repair. +_runtime_cooldown_until: dict[int, float] = {} +_runtime_cooldown_lock = threading.Lock() + +# Short-TTL "recently healthy" cache for configs that just passed a runtime +# preflight ping. Lets back-to-back turns on the same model skip the probe +# without eroding correctness — entries auto-expire and are wiped any time +# the same config is cooled down or the OR catalogue is refreshed. +_healthy_until: dict[int, float] = {} +_healthy_lock = threading.Lock() + + +@dataclass +class AutoPinResolution: + resolved_llm_config_id: int + resolved_tier: str + from_existing_pin: bool + + +def _is_usable_global_config(cfg: dict) -> bool: + return bool( + cfg.get("id") is not None + and cfg.get("model_name") + and cfg.get("provider") + and cfg.get("api_key") + ) + + +def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: + now = time.time() if now_ts is None else now_ts + stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now] + for cid in stale: + _runtime_cooldown_until.pop(cid, None) + + +def _is_runtime_cooled_down(config_id: int) -> bool: + with _runtime_cooldown_lock: + _prune_runtime_cooldowns() + return config_id in _runtime_cooldown_until + + +def mark_runtime_cooldown( + config_id: int, + *, + reason: str = "rate_limited", + cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS, +) -> None: + """Temporarily suppress a config from Auto selection. + + Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned + config that is currently unhealthy does not get immediately reused on the + same thread during repair. + """ + if cooldown_seconds <= 0: + cooldown_seconds = _RUNTIME_COOLDOWN_SECONDS + until = time.time() + int(cooldown_seconds) + with _runtime_cooldown_lock: + _runtime_cooldown_until[int(config_id)] = until + _prune_runtime_cooldowns() + # A cooled cfg can never be "recently healthy"; drop any stale credit so + # the next turn that resolves to it (after cooldown) re-runs preflight. 
+ clear_healthy(int(config_id)) + logger.info( + "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s", + config_id, + reason, + cooldown_seconds, + ) + + +def clear_runtime_cooldown(config_id: int | None = None) -> None: + """Test/ops helper to clear runtime cooldown entries.""" + with _runtime_cooldown_lock: + if config_id is None: + _runtime_cooldown_until.clear() + return + _runtime_cooldown_until.pop(int(config_id), None) + + +def _prune_healthy(now_ts: float | None = None) -> None: + now = time.time() if now_ts is None else now_ts + stale = [cid for cid, until in _healthy_until.items() if until <= now] + for cid in stale: + _healthy_until.pop(cid, None) + + +def is_recently_healthy(config_id: int) -> bool: + """Return True if ``config_id`` passed preflight within the TTL window.""" + with _healthy_lock: + _prune_healthy() + return int(config_id) in _healthy_until + + +def mark_healthy( + config_id: int, + *, + ttl_seconds: int = _HEALTHY_TTL_SECONDS, +) -> None: + """Record that ``config_id`` just passed a preflight probe. + + Subsequent calls within ``ttl_seconds`` can skip the preflight ping. The + healthy state is intentionally process-local — it's a latency hint, not a + correctness primitive — so multi-worker drift is acceptable. + """ + if ttl_seconds <= 0: + ttl_seconds = _HEALTHY_TTL_SECONDS + until = time.time() + int(ttl_seconds) + with _healthy_lock: + _healthy_until[int(config_id)] = until + _prune_healthy() + + +def clear_healthy(config_id: int | None = None) -> None: + """Drop one (or all) healthy-cache entries. + + Called from runtime cooldown and OR catalogue refresh so a freshly cooled + or replaced config never carries stale "healthy" credit. + """ + with _healthy_lock: + if config_id is None: + _healthy_until.clear() + return + _healthy_until.pop(int(config_id), None) + + +def _cfg_supports_image_input(cfg: dict) -> bool: + """True if the global cfg can accept image inputs. + + Prefers the explicit ``supports_image_input`` flag (set by the YAML + loader / OpenRouter integration). Falls back to a LiteLLM lookup so + a YAML entry whose flag was somehow stripped doesn't get wrongly + excluded. Default-allows on unknown — the streaming-task safety net + is the actual block, not this filter. + """ + if "supports_image_input" in cfg: + return bool(cfg.get("supports_image_input")) + # Lazy import: provider_capabilities -> llm_config -> services chain; + # importing at module load would create an init-order cycle through + # ``app.config``. + from app.services.provider_capabilities import derive_supports_image_input + + cfg_litellm_params = cfg.get("litellm_params") or {} + base_model = ( + cfg_litellm_params.get("base_model") + if isinstance(cfg_litellm_params, dict) + else None + ) + return derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + + +def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: + """Return Auto-eligible global cfgs. + + Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime + below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers + can't be picked as the thread's pin. Also excludes configs currently + in runtime cooldown (e.g. temporary 429 bursts). + + When ``requires_image_input`` is True (image turn), additionally + filters out configs whose ``supports_image_input`` resolves to False + so a text-only deployment can't be pinned for an image request. 
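+
+    Put differently, assuming the cfg dict shape used above: a cfg
+    survives iff it is usable, not ``health_gated``, not in runtime
+    cooldown, and, on image turns, resolves as vision-capable.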
+ """ + candidates = [ + cfg + for cfg in config.GLOBAL_LLM_CONFIGS + if _is_usable_global_config(cfg) + and not cfg.get("health_gated") + and not _is_runtime_cooled_down(int(cfg.get("id", 0))) + and (not requires_image_input or _cfg_supports_image_input(cfg)) + ] + return sorted(candidates, key=lambda c: int(c.get("id", 0))) + + +def _tier_of(cfg: dict) -> str: + return str(cfg.get("billing_tier", "free")).lower() + + +def _is_preferred_premium_auto_config(cfg: dict) -> bool: + """Return True for the operator-preferred premium Auto model.""" + return ( + _tier_of(cfg) == "premium" + and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI" + and str(cfg.get("model_name", "")).lower() == "gpt-5.4" + ) + + +def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: + """Pick a config with quality-first ranking + deterministic spread. + + Tier policy is lock-first: prefer Tier A (operator-curated YAML) + cfgs and only fall through to Tier B/C (dynamic OpenRouter) if no + Tier A cfg is eligible after upstream filters. Within the locked + pool, sort by ``quality_score`` and pick from the top-K via + ``SHA256(thread_id)`` so different new threads spread across the + best models without ever picking a low-ranked one. + + Returns ``(chosen_cfg, top_k_size)``. ``top_k_size`` is exposed for + structured logging in the caller. + """ + tier_a = [c for c in eligible if c.get("auto_pin_tier") in (None, "A")] + pool = tier_a if tier_a else eligible + pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0)) + top_k = pool[:_QUALITY_TOP_K] + digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest() + idx = int.from_bytes(digest[:8], "big") % len(top_k) + return top_k[idx], len(top_k) + + +def _to_uuid(user_id: str | UUID | None) -> UUID | None: + if user_id is None: + return None + if isinstance(user_id, UUID): + return user_id + try: + return UUID(str(user_id)) + except Exception: + return None + + +async def _is_premium_eligible( + session: AsyncSession, user_id: str | UUID | None +) -> bool: + parsed = _to_uuid(user_id) + if parsed is None: + return False + usage = await TokenQuotaService.premium_get_usage(session, parsed) + return bool(usage.allowed) + + +async def resolve_or_get_pinned_llm_config_id( + session: AsyncSession, + *, + thread_id: int, + search_space_id: int, + user_id: str | UUID | None, + selected_llm_config_id: int, + force_repin_free: bool = False, + exclude_config_ids: set[int] | None = None, + requires_image_input: bool = False, +) -> AutoPinResolution: + """Resolve Auto (Fastest) to one concrete config id and persist the pin. + + For non-auto selections, this function clears any existing pin and returns + the selected id as-is. + + When ``requires_image_input`` is True (the current turn carries an + ``image_url`` block), the candidate pool is filtered to vision-capable + cfgs and any existing pin that can't accept image input is treated as + invalid (force re-pin). If no vision-capable cfg is available the + function raises ``ValueError`` so the streaming task surfaces the same + friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of + silently routing the image to a text-only deployment. 
+ """ + thread = ( + ( + await session.execute( + select(NewChatThread) + .where(NewChatThread.id == thread_id) + .with_for_update(of=NewChatThread) + ) + ) + .unique() + .scalar_one_or_none() + ) + if thread is None: + raise ValueError(f"Thread {thread_id} not found") + if thread.search_space_id != search_space_id: + raise ValueError( + f"Thread {thread_id} does not belong to search space {search_space_id}" + ) + + # Explicit model selected: clear any stale pin. + if selected_llm_config_id != AUTO_FASTEST_ID: + if thread.pinned_llm_config_id is not None: + thread.pinned_llm_config_id = None + await session.commit() + return AutoPinResolution( + resolved_llm_config_id=selected_llm_config_id, + resolved_tier="explicit", + from_existing_pin=False, + ) + + excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} + candidates = [ + c + for c in _global_candidates(requires_image_input=requires_image_input) + if int(c.get("id", 0)) not in excluded_ids + ] + if not candidates: + if requires_image_input: + # Distinguish the "no vision-capable cfg" case from generic + # "no usable cfg" so the streaming task can map this to the + # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error. + raise ValueError( + "No vision-capable global LLM configs are available for Auto mode" + ) + raise ValueError("No usable global LLM configs are available for Auto mode") + candidate_by_id = {int(c["id"]): c for c in candidates} + + # Reuse an existing valid pin without re-checking current quota (no silent + # tier switch), unless the caller explicitly requests a forced repin to free + # *or* the turn requires image input but the pin can't handle it. + pinned_id = thread.pinned_llm_config_id + if ( + not force_repin_free + and pinned_id is not None + and int(pinned_id) in candidate_by_id + ): + pinned_cfg = candidate_by_id[int(pinned_id)] + logger.info( + "auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s", + thread_id, + search_space_id, + pinned_id, + _tier_of(pinned_cfg), + ) + logger.info( + "auto_pin_resolved thread_id=%s config_id=%s tier=%s " + "auto_pin_tier=%s score=%s top_k_size=0 from_existing_pin=True", + thread_id, + pinned_id, + _tier_of(pinned_cfg), + pinned_cfg.get("auto_pin_tier", "?"), + int(pinned_cfg.get("quality_score") or 0), + ) + return AutoPinResolution( + resolved_llm_config_id=int(pinned_id), + resolved_tier=_tier_of(pinned_cfg), + from_existing_pin=True, + ) + if pinned_id is not None: + # If the pin is *only* invalid because it can't handle the image + # turn (it's still a healthy, usable config in the broader pool), + # log that explicitly so operators can correlate the re-pin with + # the user's image attachment instead of suspecting a cooldown. 
+ if requires_image_input: + try: + pinned_global = next( + c + for c in config.GLOBAL_LLM_CONFIGS + if int(c.get("id", 0)) == int(pinned_id) + ) + except StopIteration: + pinned_global = None + if pinned_global is not None and not _cfg_supports_image_input( + pinned_global + ): + logger.info( + "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " + "previous_config_id=%s", + thread_id, + search_space_id, + pinned_id, + ) + logger.info( + "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", + thread_id, + search_space_id, + pinned_id, + ) + + premium_eligible = ( + False if force_repin_free else await _is_premium_eligible(session, user_id) + ) + if premium_eligible: + premium_candidates = [c for c in candidates if _tier_of(c) == "premium"] + preferred_premium = [ + c for c in premium_candidates if _is_preferred_premium_auto_config(c) + ] + eligible = preferred_premium or premium_candidates + else: + eligible = [c for c in candidates if _tier_of(c) != "premium"] + + if not eligible: + if requires_image_input: + raise ValueError( + "Auto mode could not find a vision-capable LLM config for this user and quota state" + ) + raise ValueError( + "Auto mode could not find an eligible LLM config for this user and quota state" + ) + + selected_cfg, top_k_size = _select_pin(eligible, thread_id) + selected_id = int(selected_cfg["id"]) + selected_tier = _tier_of(selected_cfg) + + thread.pinned_llm_config_id = selected_id + await session.commit() + + if force_repin_free: + logger.info( + "auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s", + thread_id, + search_space_id, + pinned_id, + selected_id, + ) + + if pinned_id is None: + logger.info( + "auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s", + thread_id, + search_space_id, + selected_id, + selected_tier, + premium_eligible, + ) + else: + logger.info( + "auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s", + thread_id, + search_space_id, + pinned_id, + selected_id, + selected_tier, + premium_eligible, + ) + + logger.info( + "auto_pin_resolved thread_id=%s config_id=%s tier=%s " + "auto_pin_tier=%s score=%s top_k_size=%d from_existing_pin=False", + thread_id, + selected_id, + selected_tier, + selected_cfg.get("auto_pin_tier", "?"), + int(selected_cfg.get("quality_score") or 0), + top_k_size, + ) + + return AutoPinResolution( + resolved_llm_config_id=selected_id, + resolved_tier=selected_tier, + from_existing_pin=False, + ) diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py new file mode 100644 index 000000000..92ccd6a78 --- /dev/null +++ b/surfsense_backend/app/services/billable_calls.py @@ -0,0 +1,566 @@ +""" +Per-call billable wrapper for image generation, vision LLM extraction, and +any other short-lived premium operation that must charge against the user's +shared premium credit pool. + +The ``billable_call`` async context manager encapsulates the standard +"reserve → execute → finalize / release → record audit row" lifecycle in a +single primitive so callers (the image-generation REST route and the +vision-LLM wrapper used during indexing) don't have to re-implement it. + +KEY DESIGN POINTS (issue A, B): + +1. **Session isolation.** ``billable_call`` takes no caller transaction. + All ``TokenQuotaService.premium_*`` calls and the audit-row insert run + inside their own session context. 
Route callers use + ``shielded_async_session()`` by default; Celery callers can provide a + worker-loop-safe session factory. This guarantees that quota + commit/rollback can never accidentally flush or roll back rows the caller + has staged in its main session (e.g. a freshly-created + ``ImageGeneration`` row). + +2. **ContextVar safety.** The accumulator is scoped via + :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a + nested ``billable_call`` inside an outer chat turn cannot corrupt the + chat turn's accumulator. + +3. **Free configs are still audited.** Free calls bypass the reserve / + finalize dance entirely but still record a ``TokenUsage`` audit row with + the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution + pipeline complete for analytics even when nothing is debited. + +4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is + responsible for translating that into HTTP 402. We *do not* catch the + denial inside ``billable_call`` — letting it propagate also prevents + the image-generation route from creating an ``ImageGeneration`` row + for a request that never actually ran. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import AsyncIterator, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress +from typing import Any +from uuid import UUID, uuid4 + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import shielded_async_session +from app.services.token_quota_service import ( + TokenQuotaService, + estimate_call_reserve_micros, +) +from app.services.token_tracking_service import ( + TurnTokenAccumulator, + record_token_usage, + scoped_turn, +) + +logger = logging.getLogger(__name__) + +AUDIT_TIMEOUT_SECONDS = 10.0 +BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset( + {"video_presentation_generation", "podcast_generation"} +) +BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]] + + +class QuotaInsufficientError(Exception): + """Raised when ``TokenQuotaService.premium_reserve`` denies a billable + call because the user has exhausted their premium credit pool. + + The route handler should catch this and return HTTP 402 Payment + Required (or the equivalent for the surface area). Outside of the HTTP + layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing) + callers may catch this and degrade gracefully — e.g. fall back to OCR + when vision is unavailable. 
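+
+    Minimal route-side handler sketch (FastAPI shape assumed, names
+    illustrative)::
+
+        try:
+            async with billable_call(...) as acc:
+                ...  # run the provider call
+        except QuotaInsufficientError as exc:
+            raise HTTPException(status_code=402, detail=str(exc)) from exc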
+ """ + + def __init__( + self, + *, + usage_type: str, + used_micros: int, + limit_micros: int, + remaining_micros: int, + ) -> None: + self.usage_type = usage_type + self.used_micros = used_micros + self.limit_micros = limit_micros + self.remaining_micros = remaining_micros + super().__init__( + f"Premium credit exhausted for {usage_type}: " + f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)" + ) + + +class BillingSettlementError(Exception): + """Raised when a premium call completed but credit settlement failed.""" + + def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None: + self.usage_type = usage_type + self.user_id = user_id + super().__init__( + f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}" + ) + + +async def _rollback_safely(session: AsyncSession) -> None: + rollback = getattr(session, "rollback", None) + if rollback is not None: + with suppress(Exception): + await rollback() + + +async def _record_audit_best_effort( + *, + session_factory: BillableSessionFactory, + usage_type: str, + search_space_id: int, + user_id: UUID, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + cost_micros: int, + model_breakdown: dict[str, Any], + call_details: dict[str, Any] | None, + thread_id: int | None, + message_id: int | None, + audit_label: str, + timeout_seconds: float = AUDIT_TIMEOUT_SECONDS, +) -> None: + """Persist a TokenUsage row without letting audit failure block callers. + + Premium settlement is mandatory, but TokenUsage is an audit trail. If the + audit insert or commit hangs, user-facing artifacts such as videos and + podcasts must still be able to transition to READY after settlement. + """ + audit_thread_id = ( + None if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES else thread_id + ) + + async def _persist() -> None: + logger.info( + "[billable_call] audit start label=%s usage_type=%s user=%s thread=%s " + "total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + total_tokens, + cost_micros, + ) + async with session_factory() as audit_session: + try: + await record_token_usage( + audit_session, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost_micros=cost_micros, + model_breakdown=model_breakdown, + call_details=call_details, + thread_id=audit_thread_id, + message_id=message_id, + ) + logger.info( + "[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s", + audit_label, + usage_type, + user_id, + audit_thread_id, + ) + await audit_session.commit() + logger.info( + "[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s", + audit_label, + usage_type, + user_id, + audit_thread_id, + ) + except BaseException: + await _rollback_safely(audit_session) + raise + + try: + await asyncio.wait_for(_persist(), timeout=timeout_seconds) + except TimeoutError: + logger.warning( + "[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s " + "timeout=%.1fs total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + timeout_seconds, + total_tokens, + cost_micros, + ) + except Exception: + logger.exception( + "[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s " + "total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + total_tokens, + cost_micros, + ) + + +@asynccontextmanager 
+async def billable_call( + *, + user_id: UUID, + search_space_id: int, + billing_tier: str, + base_model: str, + quota_reserve_tokens: int | None = None, + quota_reserve_micros_override: int | None = None, + usage_type: str, + thread_id: int | None = None, + message_id: int | None = None, + call_details: dict[str, Any] | None = None, + billable_session_factory: BillableSessionFactory | None = None, + audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS, +) -> AsyncIterator[TurnTokenAccumulator]: + """Wrap a single billable LLM/image call. + + Args: + user_id: Owner of the credit pool to debit. For vision-LLM during + indexing this is the *search-space owner* (issue M), not the + triggering user. + search_space_id: Required — recorded on the ``TokenUsage`` audit row. + billing_tier: ``"premium"`` debits; anything else (``"free"``) skips + the reserve/finalize dance but still records an audit row with + the captured cost. + base_model: Used by :func:`estimate_call_reserve_micros` to compute + a worst-case reservation from LiteLLM's pricing table. + quota_reserve_tokens: Optional per-config override for the chat-style + reserve estimator (vision LLM uses this). + quota_reserve_micros_override: Optional flat micro-USD reservation + (image generation uses this — its cost shape is per-image, not + per-token). + usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc. + Recorded on the ``TokenUsage`` row. + thread_id, message_id: Optional FK columns on ``TokenUsage``. + call_details: Optional per-call metadata (model name, parameters) + forwarded to ``record_token_usage``. + billable_session_factory: Optional async context factory used for + reserve/finalize/release/audit sessions. Defaults to + ``shielded_async_session`` for route callers; Celery callers pass + a worker-loop-safe session factory. + audit_timeout_seconds: Upper bound for TokenUsage audit persistence. + Audit failure is best-effort and does not undo successful + settlement. + + Yields: + The ``TurnTokenAccumulator`` scoped to this call. The caller invokes + the underlying LLM/image API while inside the ``async with``; the + ``TokenTrackingCallback`` populates the accumulator automatically. + + Raises: + QuotaInsufficientError: when premium and ``premium_reserve`` denies. + """ + is_premium = billing_tier == "premium" + session_factory = billable_session_factory or shielded_async_session + + async with scoped_turn() as acc: + # ---------- Free path: just audit ------------------------------- + if not is_premium: + try: + yield acc + finally: + # Always audit, even on exception, so we capture cost when + # provider returns successfully but the caller raises later. 
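+                # (For example: the provider responds successfully but the
+                # caller raises while post-processing the result; the
+                # accumulator already holds the reported cost, so the
+                # audit row still lands.)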
+ await _record_audit_best_effort( + session_factory=session_factory, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=acc.total_cost_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + audit_label="free", + timeout_seconds=audit_timeout_seconds, + ) + return + + # ---------- Premium path: reserve → execute → finalize ---------- + if quota_reserve_micros_override is not None: + reserve_micros = max(1, int(quota_reserve_micros_override)) + else: + reserve_micros = estimate_call_reserve_micros( + base_model=base_model or "", + quota_reserve_tokens=quota_reserve_tokens, + ) + + request_id = str(uuid4()) + + async with session_factory() as quota_session: + reserve_result = await TokenQuotaService.premium_reserve( + db_session=quota_session, + user_id=user_id, + request_id=request_id, + reserve_micros=reserve_micros, + ) + + if not reserve_result.allowed: + logger.info( + "[billable_call] reserve DENIED user=%s usage_type=%s " + "reserve=%d used=%d limit=%d remaining=%d", + user_id, + usage_type, + reserve_micros, + reserve_result.used, + reserve_result.limit, + reserve_result.remaining, + ) + raise QuotaInsufficientError( + usage_type=usage_type, + used_micros=reserve_result.used, + limit_micros=reserve_result.limit, + remaining_micros=reserve_result.remaining, + ) + + logger.info( + "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d " + "(remaining=%d)", + user_id, + usage_type, + reserve_micros, + reserve_result.remaining, + ) + + try: + yield acc + except BaseException: + # Release on any failure (including QuotaInsufficientError raised + # from a downstream call, asyncio cancellation, etc.). We use + # BaseException so cancellation also releases. + try: + async with session_factory() as quota_session: + await TokenQuotaService.premium_release( + db_session=quota_session, + user_id=user_id, + reserved_micros=reserve_micros, + ) + except Exception: + logger.exception( + "[billable_call] premium_release failed for user=%s " + "reserve_micros=%d (reservation will be GC'd by quota " + "reconciliation if/when implemented)", + user_id, + reserve_micros, + ) + raise + + # ---------- Success: finalize + audit ---------------------------- + actual_micros = acc.total_cost_micros + try: + logger.info( + "[billable_call] finalize start user=%s usage_type=%s actual=%d " + "reserved=%d thread=%s", + user_id, + usage_type, + actual_micros, + reserve_micros, + thread_id, + ) + async with session_factory() as quota_session: + final_result = await TokenQuotaService.premium_finalize( + db_session=quota_session, + user_id=user_id, + request_id=request_id, + actual_micros=actual_micros, + reserved_micros=reserve_micros, + ) + logger.info( + "[billable_call] finalize user=%s usage_type=%s actual=%d " + "reserved=%d → used=%d/%d (remaining=%d)", + user_id, + usage_type, + actual_micros, + reserve_micros, + final_result.used, + final_result.limit, + final_result.remaining, + ) + except Exception as finalize_exc: + # Last-ditch: if finalize itself fails, we must at least release + # so the reservation doesn't leak. 
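+            # (After a successful release nothing is debited for this call;
+            # BillingSettlementError still propagates so the caller can
+            # surface the settlement failure instead of silently
+            # under-charging.)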
+ logger.exception( + "[billable_call] premium_finalize failed for user=%s; " + "attempting release", + user_id, + ) + try: + async with session_factory() as quota_session: + await TokenQuotaService.premium_release( + db_session=quota_session, + user_id=user_id, + reserved_micros=reserve_micros, + ) + except Exception: + logger.exception( + "[billable_call] release after finalize failure ALSO failed " + "for user=%s", + user_id, + ) + raise BillingSettlementError( + usage_type=usage_type, + user_id=user_id, + cause=finalize_exc, + ) from finalize_exc + + await _record_audit_best_effort( + session_factory=session_factory, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=actual_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + audit_label="premium", + timeout_seconds=audit_timeout_seconds, + ) + + +async def _resolve_agent_billing_for_search_space( + session: AsyncSession, + search_space_id: int, + *, + thread_id: int | None = None, +) -> tuple[UUID, str, str]: + """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space + agent LLM. + + Used by Celery tasks (podcast generation, video presentation) to bill the + search-space owner's premium credit pool when the agent LLM is premium. + + Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: + + - Search space not found / no ``agent_llm_id``: raise ``ValueError``. + - **Auto mode** (``id == AUTO_FASTEST_ID == 0``): + * ``thread_id`` is set: delegate to + ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and + recurse into the resolved id. Reuses chat's existing pin if present + so the same model bills for chat + downstream podcast/video. If the + user is not premium-eligible, the pin service auto-restricts to free + deployments — denial only happens later in + ``billable_call.premium_reserve`` if the pin really is premium and + credit ran out mid-flow. + * ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat + for any future direct-API path; today both Celery tasks always pass + ``thread_id``. + - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]`` + (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault), + ``base_model = litellm_params.get("base_model") or model_name`` — + NOT provider-prefixed, matching chat's cost-map lookup convention. + - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches + ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``); + ``base_model`` from ``litellm_params`` or ``model_name``. + + Note on imports: ``llm_service``, ``auto_model_pin_service``, and + ``llm_router_service`` are imported lazily inside the function body to + avoid hoisting litellm side-effects (``litellm.callbacks = + [token_tracker]``, ``litellm.drop_params``, etc.) into + ``billable_calls.py``'s module load path. 
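+
+    Illustrative outcomes (ids hypothetical): a global YAML cfg with
+    ``id=-3, billing_tier="premium"`` resolves to
+    ``(owner_id, "premium", <base_model or model_name>)``, while a BYOK
+    cfg with ``id=12`` always resolves to ``(owner_id, "free", ...)``.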
+ """ + from sqlalchemy import select + + from app.db import NewLLMConfig, SearchSpace + + result = await session.execute( + select(SearchSpace).where(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if search_space is None: + raise ValueError(f"Search space {search_space_id} not found") + + agent_llm_id = search_space.agent_llm_id + if agent_llm_id is None: + raise ValueError( + f"Search space {search_space_id} has no agent_llm_id configured" + ) + + owner_user_id: UUID = search_space.user_id + + from app.services.auto_model_pin_service import ( + AUTO_FASTEST_ID, + resolve_or_get_pinned_llm_config_id, + ) + + if agent_llm_id == AUTO_FASTEST_ID: + if thread_id is None: + return owner_user_id, "free", "auto" + try: + resolution = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=str(owner_user_id), + selected_llm_config_id=AUTO_FASTEST_ID, + ) + except ValueError: + logger.warning( + "[agent_billing] Auto-mode pin resolution failed for " + "search_space=%s thread=%s; falling back to free", + search_space_id, + thread_id, + exc_info=True, + ) + return owner_user_id, "free", "auto" + agent_llm_id = resolution.resolved_llm_config_id + + if agent_llm_id < 0: + from app.services.llm_service import get_global_llm_config + + cfg = get_global_llm_config(agent_llm_id) or {} + billing_tier = str(cfg.get("billing_tier", "free")).lower() + litellm_params = cfg.get("litellm_params") or {} + base_model = litellm_params.get("base_model") or cfg.get("model_name") or "" + return owner_user_id, billing_tier, base_model + + nlc_result = await session.execute( + select(NewLLMConfig).where( + NewLLMConfig.id == agent_llm_id, + NewLLMConfig.search_space_id == search_space_id, + ) + ) + nlc = nlc_result.scalars().first() + base_model = "" + if nlc is not None: + litellm_params = nlc.litellm_params or {} + base_model = litellm_params.get("base_model") or nlc.model_name or "" + return owner_user_id, "free", base_model + + +__all__ = [ + "BillingSettlementError", + "QuotaInsufficientError", + "_resolve_agent_billing_for_search_space", + "billable_call", +] + + +# Re-export the config knob so callers don't have to import config just for +# the default image reserve. 
+DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS diff --git a/surfsense_backend/app/services/composio_service.py b/surfsense_backend/app/services/composio_service.py index a8abe4aa8..edfab1d15 100644 --- a/surfsense_backend/app/services/composio_service.py +++ b/surfsense_backend/app/services/composio_service.py @@ -408,12 +408,37 @@ class ComposioService: files = [] next_token = None if isinstance(data, dict): + inner_data = data.get("data", data) + response_data = ( + inner_data.get("response_data", {}) + if isinstance(inner_data, dict) + else {} + ) # Try direct access first, then nested - files = data.get("files", []) or data.get("data", {}).get("files", []) + files = ( + data.get("files", []) + or ( + inner_data.get("files", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("files", []) + ) next_token = ( data.get("nextPageToken") or data.get("next_page_token") - or data.get("data", {}).get("nextPageToken") + or ( + inner_data.get("nextPageToken") + if isinstance(inner_data, dict) + else None + ) + or ( + inner_data.get("next_page_token") + if isinstance(inner_data, dict) + else None + ) + or response_data.get("nextPageToken") + or response_data.get("next_page_token") ) elif isinstance(data, list): files = data @@ -819,24 +844,61 @@ class ComposioService: next_token = None result_size_estimate = None if isinstance(data, dict): + inner_data = data.get("data", data) + response_data = ( + inner_data.get("response_data", {}) + if isinstance(inner_data, dict) + else {} + ) messages = ( data.get("messages", []) - or data.get("data", {}).get("messages", []) + or ( + inner_data.get("messages", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("messages", []) or data.get("emails", []) + or ( + inner_data.get("emails", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("emails", []) ) # Check for pagination token in various possible locations next_token = ( data.get("nextPageToken") or data.get("next_page_token") - or data.get("data", {}).get("nextPageToken") - or data.get("data", {}).get("next_page_token") + or ( + inner_data.get("nextPageToken") + if isinstance(inner_data, dict) + else None + ) + or ( + inner_data.get("next_page_token") + if isinstance(inner_data, dict) + else None + ) + or response_data.get("nextPageToken") + or response_data.get("next_page_token") ) # Extract resultSizeEstimate if available (Gmail API provides this) result_size_estimate = ( data.get("resultSizeEstimate") or data.get("result_size_estimate") - or data.get("data", {}).get("resultSizeEstimate") - or data.get("data", {}).get("result_size_estimate") + or ( + inner_data.get("resultSizeEstimate") + if isinstance(inner_data, dict) + else None + ) + or ( + inner_data.get("result_size_estimate") + if isinstance(inner_data, dict) + else None + ) + or response_data.get("resultSizeEstimate") + or response_data.get("result_size_estimate") ) elif isinstance(data, list): messages = data @@ -864,7 +926,7 @@ class ComposioService: try: result = await self.execute_tool( connected_account_id=connected_account_id, - tool_name="GMAIL_GET_MESSAGE_BY_MESSAGE_ID", + tool_name="GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID", params={"message_id": message_id}, # snake_case entity_id=entity_id, ) @@ -872,7 +934,13 @@ class ComposioService: if not result.get("success"): return None, result.get("error", "Unknown error") - return result.get("data"), None + data = result.get("data") + if isinstance(data, dict): + inner_data = data.get("data", data) + if 
isinstance(inner_data, dict): + return inner_data.get("response_data", inner_data), None + + return data, None except Exception as e: logger.error(f"Failed to get Gmail message detail: {e!s}") @@ -928,10 +996,27 @@ class ComposioService: # Try different possible response structures events = [] if isinstance(data, dict): + inner_data = data.get("data", data) + response_data = ( + inner_data.get("response_data", {}) + if isinstance(inner_data, dict) + else {} + ) events = ( data.get("items", []) - or data.get("data", {}).get("items", []) + or ( + inner_data.get("items", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("items", []) or data.get("events", []) + or ( + inner_data.get("events", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("events", []) ) elif isinstance(data, list): events = data diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py index 7c55da2e5..45bcfd00f 100644 --- a/surfsense_backend/app/services/connector_service.py +++ b/surfsense_backend/app/services/connector_service.py @@ -1,6 +1,8 @@ import asyncio +import os import time from datetime import datetime +from threading import Lock from typing import Any import httpx @@ -2769,12 +2771,22 @@ class ConnectorService: """ Get all available (enabled) connector types for a search space. + Phase 1.4: results are cached per ``search_space_id`` for + :data:`_DISCOVERY_TTL_SECONDS`. Cache key is independent of session + identity — the cached value is plain data, safe to share across + requests. Invalidate on connector add/update/delete via + :func:`invalidate_connector_discovery_cache`. + Args: search_space_id: The search space ID Returns: List of SearchSourceConnectorType enums for enabled connectors """ + cached = _get_cached_connectors(search_space_id) + if cached is not None: + return list(cached) + query = ( select(SearchSourceConnector.connector_type) .filter( @@ -2784,8 +2796,9 @@ class ConnectorService: ) result = await self.session.execute(query) - connector_types = result.scalars().all() - return list(connector_types) + connector_types = list(result.scalars().all()) + _set_cached_connectors(search_space_id, connector_types) + return connector_types async def get_available_document_types( self, @@ -2794,12 +2807,22 @@ class ConnectorService: """ Get all document types that have at least one document in the search space. + Phase 1.4: cached per ``search_space_id`` for + :data:`_DISCOVERY_TTL_SECONDS`. Invalidate via + :func:`invalidate_connector_discovery_cache` when a connector + finishes indexing new documents (or document types are otherwise + added/removed). 
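+
+        For example, indexing the first document of a new type in a space
+        makes that type appear here within one TTL window at the latest,
+        or immediately once the ORM invalidation listener fires.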
+ Args: search_space_id: The search space ID Returns: List of document type strings that have documents indexed """ + cached = _get_cached_doc_types(search_space_id) + if cached is not None: + return list(cached) + from sqlalchemy import distinct from app.db import Document @@ -2809,5 +2832,164 @@ class ConnectorService: ) result = await self.session.execute(query) - doc_types = result.scalars().all() - return [str(dt) for dt in doc_types] + doc_types = [str(dt) for dt in result.scalars().all()] + _set_cached_doc_types(search_space_id, doc_types) + return doc_types + + +# --------------------------------------------------------------------------- +# Connector / document-type discovery TTL cache (Phase 1.4) +# --------------------------------------------------------------------------- +# +# Both ``get_available_connectors`` and ``get_available_document_types`` are +# called on EVERY chat turn from ``create_surfsense_deep_agent``. Each query +# hits Postgres and contributes to per-turn agent build latency. Their +# results change infrequently — only when the user adds/edits/removes a +# connector, or when an indexer commits a new document type. A short TTL +# cache (default 30s, env-tunable) collapses N concurrent calls into one +# DB roundtrip with bounded staleness. +# +# Invalidation: connector mutation routes (create / update / delete) call +# ``invalidate_connector_discovery_cache(search_space_id)`` to clear the +# entry for the affected space. Multi-replica deployments still pay one +# DB roundtrip per replica per TTL window, which is fine — staleness is +# bounded and the alternative (cross-replica fanout) is not worth the +# coupling here. + +_DISCOVERY_TTL_SECONDS: float = float( + os.getenv("SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS", "30") +) + +# Per-search-space caches. Keyed by ``search_space_id``; value is +# ``(expires_at_monotonic, payload)``. Plain dicts protected by a lock — +# read-mostly workload, sub-microsecond contention. 
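+#
+# Manual invalidation sketch (hypothetical caller code): a bulk import that
+# bypasses the ORM should run
+#
+#     invalidate_connector_discovery_cache(search_space_id)
+#
+# after commit, so the next chat turn re-reads Postgres instead of a stale
+# entry; ORM-mapped writes are already covered by the listeners below.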
+_connectors_cache: dict[int, tuple[float, list[SearchSourceConnectorType]]] = {} +_doc_types_cache: dict[int, tuple[float, list[str]]] = {} +_cache_lock = Lock() + + +def _get_cached_connectors( + search_space_id: int, +) -> list[SearchSourceConnectorType] | None: + if _DISCOVERY_TTL_SECONDS <= 0: + return None + with _cache_lock: + entry = _connectors_cache.get(search_space_id) + if entry is None: + return None + expires_at, payload = entry + if time.monotonic() >= expires_at: + _connectors_cache.pop(search_space_id, None) + return None + return payload + + +def _set_cached_connectors( + search_space_id: int, payload: list[SearchSourceConnectorType] +) -> None: + if _DISCOVERY_TTL_SECONDS <= 0: + return + expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS + with _cache_lock: + _connectors_cache[search_space_id] = (expires_at, list(payload)) + + +def _get_cached_doc_types(search_space_id: int) -> list[str] | None: + if _DISCOVERY_TTL_SECONDS <= 0: + return None + with _cache_lock: + entry = _doc_types_cache.get(search_space_id) + if entry is None: + return None + expires_at, payload = entry + if time.monotonic() >= expires_at: + _doc_types_cache.pop(search_space_id, None) + return None + return payload + + +def _set_cached_doc_types(search_space_id: int, payload: list[str]) -> None: + if _DISCOVERY_TTL_SECONDS <= 0: + return + expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS + with _cache_lock: + _doc_types_cache[search_space_id] = (expires_at, list(payload)) + + +def invalidate_connector_discovery_cache(search_space_id: int | None = None) -> None: + """Drop cached discovery results for ``search_space_id`` (or all spaces). + + Connector CRUD routes / indexer pipelines call this when they mutate + the rows backing :func:`ConnectorService.get_available_connectors` / + :func:`get_available_document_types`. ``None`` clears every space — + useful in tests and on bulk imports. + """ + with _cache_lock: + if search_space_id is None: + _connectors_cache.clear() + _doc_types_cache.clear() + else: + _connectors_cache.pop(search_space_id, None) + _doc_types_cache.pop(search_space_id, None) + + +def _invalidate_connectors_only(search_space_id: int | None = None) -> None: + with _cache_lock: + if search_space_id is None: + _connectors_cache.clear() + else: + _connectors_cache.pop(search_space_id, None) + + +def _invalidate_doc_types_only(search_space_id: int | None = None) -> None: + with _cache_lock: + if search_space_id is None: + _doc_types_cache.clear() + else: + _doc_types_cache.pop(search_space_id, None) + + +def _register_invalidation_listeners() -> None: + """Wire SQLAlchemy ORM events so cache stays consistent automatically. + + Listening on ``after_insert`` / ``after_update`` / ``after_delete`` + means every successful INSERT/UPDATE/DELETE that goes through the ORM + invalidates the affected search space's cached discovery payload — + no need to sprinkle ``invalidate_*`` calls across 30+ connector + routes. Bulk operations that bypass the ORM (e.g. + ``session.execute(insert(...))`` without a mapped object) still need + explicit invalidation; document indexers already commit through the + ORM so document-type discovery is covered. + """ + from sqlalchemy import event + + # Imported here (not at module top) to avoid a circular import: + # app.services.connector_service is itself imported from app.db's + # ecosystem indirectly via several CRUD modules. 
+ from app.db import Document, SearchSourceConnector + + def _connector_changed(_mapper, _connection, target) -> None: + sid = getattr(target, "search_space_id", None) + if sid is not None: + _invalidate_connectors_only(int(sid)) + + def _document_changed(_mapper, _connection, target) -> None: + sid = getattr(target, "search_space_id", None) + if sid is not None: + _invalidate_doc_types_only(int(sid)) + + for evt in ("after_insert", "after_update", "after_delete"): + event.listen(SearchSourceConnector, evt, _connector_changed) + event.listen(Document, evt, _document_changed) + + +try: + _register_invalidation_listeners() +except Exception: # pragma: no cover - defensive; never block module import + import logging as _logging + + _logging.getLogger(__name__).exception( + "Failed to register connector discovery cache invalidation listeners; " + "stale cache risk: explicit invalidate_connector_discovery_cache calls " + "may be required." + ) diff --git a/surfsense_backend/app/services/gmail/tool_metadata_service.py b/surfsense_backend/app/services/gmail/tool_metadata_service.py index c903e24af..4855c1cc9 100644 --- a/surfsense_backend/app/services/gmail/tool_metadata_service.py +++ b/surfsense_backend/app/services/gmail/tool_metadata_service.py @@ -17,7 +17,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) -from app.utils.google_credentials import build_composio_credentials +from app.services.composio_service import ComposioService logger = logging.getLogger(__name__) @@ -78,14 +78,49 @@ class GmailToolMetadataService: def __init__(self, db_session: AsyncSession): self._db_session = db_session - async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: - if ( + def _is_composio_connector(self, connector: SearchSourceConnector) -> bool: + return ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - return build_composio_credentials(cca_id) + ) + + def _get_composio_connected_account_id( + self, connector: SearchSourceConnector + ) -> str: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected_account_id not found") + return cca_id + + def _unwrap_composio_data(self, data: Any) -> Any: + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner) + return inner + return data + + async def _execute_composio_gmail_tool( + self, + connector: SearchSourceConnector, + tool_name: str, + params: dict[str, Any], + ) -> tuple[Any, str | None]: + result = await ComposioService().execute_tool( + connected_account_id=self._get_composio_connected_account_id(connector), + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{connector.user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Gmail error") + return self._unwrap_composio_data(result.get("data")), None + + async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: + if self._is_composio_connector(connector): + raise ValueError( + "Composio Gmail connectors must use Composio tool execution" + ) config_data = dict(connector.config) @@ -139,6 +174,12 @@ class GmailToolMetadataService: if not connector: return True + if self._is_composio_connector(connector): + _profile, error = await self._execute_composio_gmail_tool( + connector, "GMAIL_GET_PROFILE", {"user_id": "me"} + 
) + return bool(error) + creds = await self._build_credentials(connector) service = build("gmail", "v1", credentials=creds) await asyncio.get_event_loop().run_in_executor( @@ -221,14 +262,21 @@ class GmailToolMetadataService: ) connector = result.scalar_one_or_none() if connector: - creds = await self._build_credentials(connector) - service = build("gmail", "v1", credentials=creds) - profile = await asyncio.get_event_loop().run_in_executor( - None, - lambda service=service: ( - service.users().getProfile(userId="me").execute() - ), - ) + if self._is_composio_connector(connector): + profile, error = await self._execute_composio_gmail_tool( + connector, "GMAIL_GET_PROFILE", {"user_id": "me"} + ) + if error: + raise RuntimeError(error) + else: + creds = await self._build_credentials(connector) + service = build("gmail", "v1", credentials=creds) + profile = await asyncio.get_event_loop().run_in_executor( + None, + lambda service=service: ( + service.users().getProfile(userId="me").execute() + ), + ) acc_dict["email"] = profile.get("emailAddress", "") except Exception: logger.warning( @@ -298,6 +346,23 @@ class GmailToolMetadataService: Returns ``None`` on any failure so callers can degrade gracefully. """ try: + if self._is_composio_connector(connector): + if not draft_id: + draft_id = await self._find_composio_draft_id(connector, message_id) + if not draft_id: + return None + + draft, error = await self._execute_composio_gmail_tool( + connector, + "GMAIL_GET_DRAFT", + {"user_id": "me", "draft_id": draft_id, "format": "full"}, + ) + if error or not isinstance(draft, dict): + return None + + payload = draft.get("message", {}).get("payload", {}) + return self._extract_body_from_payload(payload) + creds = await self._build_credentials(connector) service = build("gmail", "v1", credentials=creds) @@ -326,6 +391,33 @@ class GmailToolMetadataService: ) return None + async def _find_composio_draft_id( + self, connector: SearchSourceConnector, message_id: str + ) -> str | None: + page_token = "" + while True: + params: dict[str, Any] = { + "user_id": "me", + "max_results": 100, + "verbose": False, + } + if page_token: + params["page_token"] = page_token + + data, error = await self._execute_composio_gmail_tool( + connector, "GMAIL_LIST_DRAFTS", params + ) + if error or not isinstance(data, dict): + return None + + for draft in data.get("drafts", []): + if draft.get("message", {}).get("id") == message_id: + return draft.get("id") + + page_token = data.get("nextPageToken") or data.get("next_page_token") or "" + if not page_token: + return None + async def _find_draft_id(self, service: Any, message_id: str) -> str | None: """Resolve a draft ID from its message ID by scanning drafts.list.""" try: diff --git a/surfsense_backend/app/services/google_calendar/kb_sync_service.py b/surfsense_backend/app/services/google_calendar/kb_sync_service.py index 20426f3bc..602a55738 100644 --- a/surfsense_backend/app/services/google_calendar/kb_sync_service.py +++ b/surfsense_backend/app/services/google_calendar/kb_sync_service.py @@ -14,6 +14,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) +from app.services.composio_service import ComposioService from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -21,7 +22,6 @@ from app.utils.document_converters import ( generate_document_summary, generate_unique_identifier_hash, ) -from app.utils.google_credentials import build_composio_credentials logger = logging.getLogger(__name__) @@ -203,23 +203,46 @@ class 
GoogleCalendarKBSyncService: logger.warning("Document %s not found in KB", document_id) return {"status": "not_indexed"} - creds = await self._build_credentials_for_connector(connector_id) - loop = asyncio.get_event_loop() - service = await loop.run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - calendar_id = (document.document_metadata or {}).get( "calendar_id" ) or "primary" - live_event = await loop.run_in_executor( - None, - lambda: ( - service.events() - .get(calendarId=calendar_id, eventId=event_id) - .execute() - ), - ) + connector = await self._get_connector(connector_id) + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ): + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected_account_id not found") + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_EVENTS_GET", + params={"calendar_id": calendar_id, "event_id": event_id}, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get("error", "Unknown Composio Calendar error") + ) + live_event = composio_result.get("data", {}) + if isinstance(live_event, dict): + live_event = live_event.get("data", live_event) + if isinstance(live_event, dict): + live_event = live_event.get("response_data", live_event) + else: + creds = await self._build_credentials_for_connector(connector_id) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + live_event = await loop.run_in_executor( + None, + lambda: ( + service.events() + .get(calendarId=calendar_id, eventId=event_id) + .execute() + ), + ) event_summary = live_event.get("summary", "") description = live_event.get("description", "") @@ -322,7 +345,7 @@ class GoogleCalendarKBSyncService: await self.db_session.rollback() return {"status": "error", "message": str(e)} - async def _build_credentials_for_connector(self, connector_id: int) -> Credentials: + async def _get_connector(self, connector_id: int) -> SearchSourceConnector: result = await self.db_session.execute( select(SearchSourceConnector).where( SearchSourceConnector.id == connector_id @@ -331,15 +354,17 @@ class GoogleCalendarKBSyncService: connector = result.scalar_one_or_none() if not connector: raise ValueError(f"Connector {connector_id} not found") + return connector + async def _build_credentials_for_connector(self, connector_id: int) -> Credentials: + connector = await self._get_connector(connector_id) if ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - return build_composio_credentials(cca_id) - raise ValueError("Composio connected_account_id not found") + raise ValueError( + "Composio Calendar connectors must use Composio tool execution" + ) config_data = dict(connector.config) diff --git a/surfsense_backend/app/services/google_calendar/tool_metadata_service.py b/surfsense_backend/app/services/google_calendar/tool_metadata_service.py index c7bfe1d50..7e50ab039 100644 --- a/surfsense_backend/app/services/google_calendar/tool_metadata_service.py +++ b/surfsense_backend/app/services/google_calendar/tool_metadata_service.py @@ -16,7 +16,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) -from app.utils.google_credentials import 
build_composio_credentials +from app.services.composio_service import ComposioService logger = logging.getLogger(__name__) @@ -94,15 +94,49 @@ class GoogleCalendarToolMetadataService: def __init__(self, db_session: AsyncSession): self._db_session = db_session - async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: - if ( + def _is_composio_connector(self, connector: SearchSourceConnector) -> bool: + return ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - return build_composio_credentials(cca_id) + ) + + def _get_composio_connected_account_id( + self, connector: SearchSourceConnector + ) -> str: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: raise ValueError("Composio connected_account_id not found") + return cca_id + + async def _execute_composio_calendar_tool( + self, + connector: SearchSourceConnector, + tool_name: str, + params: dict, + ) -> tuple[dict | list | None, str | None]: + service = ComposioService() + result = await service.execute_tool( + connected_account_id=self._get_composio_connected_account_id(connector), + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{connector.user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Calendar error") + + data = result.get("data") + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner), None + return inner, None + return data, None + + async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: + if self._is_composio_connector(connector): + raise ValueError( + "Composio Calendar connectors must use Composio tool execution" + ) config_data = dict(connector.config) @@ -156,6 +190,14 @@ class GoogleCalendarToolMetadataService: if not connector: return True + if self._is_composio_connector(connector): + _data, error = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_GET_CALENDAR", + {"calendar_id": "primary"}, + ) + return bool(error) + creds = await self._build_credentials(connector) loop = asyncio.get_event_loop() await loop.run_in_executor( @@ -255,16 +297,48 @@ class GoogleCalendarToolMetadataService: timezone_str = "" if connector: try: - creds = await self._build_credentials(connector) - loop = asyncio.get_event_loop() - service = await loop.run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) + if self._is_composio_connector(connector): + cal_list, cal_error = await self._execute_composio_calendar_tool( + connector, "GOOGLECALENDAR_LIST_CALENDARS", {} + ) + if cal_error: + raise RuntimeError(cal_error) + ( + settings, + settings_error, + ) = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_SETTINGS_GET", + {"setting": "timezone"}, + ) + if not settings_error and isinstance(settings, dict): + timezone_str = settings.get("value", "") + else: + creds = await self._build_credentials(connector) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) - cal_list = await loop.run_in_executor( - None, lambda: service.calendarList().list().execute() - ) - for cal in cal_list.get("items", []): + cal_list = await loop.run_in_executor( + None, lambda: service.calendarList().list().execute() + ) + + tz_setting = await loop.run_in_executor( + 
None, + lambda: service.settings().get(setting="timezone").execute(), + ) + timezone_str = tz_setting.get("value", "") + + calendar_items = [] + if isinstance(cal_list, dict): + calendar_items = ( + cal_list.get("items") or cal_list.get("calendars") or [] + ) + elif isinstance(cal_list, list): + calendar_items = cal_list + + for cal in calendar_items: calendars.append( { "id": cal.get("id", ""), @@ -272,12 +346,6 @@ class GoogleCalendarToolMetadataService: "primary": cal.get("primary", False), } ) - - tz_setting = await loop.run_in_executor( - None, - lambda: service.settings().get(setting="timezone").execute(), - ) - timezone_str = tz_setting.get("value", "") except Exception: logger.warning( "Failed to fetch calendars/timezone for connector %s", @@ -321,20 +389,29 @@ class GoogleCalendarToolMetadataService: event_dict = event.to_dict() try: - creds = await self._build_credentials(connector) - loop = asyncio.get_event_loop() - service = await loop.run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) calendar_id = event.calendar_id or "primary" - live_event = await loop.run_in_executor( - None, - lambda: ( - service.events() - .get(calendarId=calendar_id, eventId=event.event_id) - .execute() - ), - ) + if self._is_composio_connector(connector): + live_event, error = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_EVENTS_GET", + {"calendar_id": calendar_id, "event_id": event.event_id}, + ) + if error: + raise RuntimeError(error) + else: + creds = await self._build_credentials(connector) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + live_event = await loop.run_in_executor( + None, + lambda: ( + service.events() + .get(calendarId=calendar_id, eventId=event.event_id) + .execute() + ), + ) event_dict["summary"] = live_event.get("summary", event_dict["summary"]) event_dict["description"] = live_event.get( @@ -376,12 +453,30 @@ class GoogleCalendarToolMetadataService: ) -> dict: resolved = await self._resolve_event(search_space_id, user_id, event_ref) if not resolved: + live_resolved = await self._resolve_live_event( + search_space_id, user_id, event_ref + ) + if not live_resolved: + return { + "error": ( + f"Event '{event_ref}' not found in your indexed or live Google Calendar events. " + "This could mean: (1) the event doesn't exist, " + "(2) the event name is different, or " + "(3) the connected calendar account cannot access it." + ) + } + + connector, live_event = live_resolved + account = GoogleCalendarAccount.from_connector(connector) + acc_dict = account.to_dict() + auth_expired = await self._check_account_health(connector.id) + acc_dict["auth_expired"] = auth_expired + if auth_expired: + await self._persist_auth_expired(connector.id) + return { - "error": ( - f"Event '{event_ref}' not found in your indexed Google Calendar events. " - "This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, " - "or (3) the event name is different." 
- ) + "account": acc_dict, + "event": self._event_dict_from_live_event(live_event), } document, connector = resolved @@ -429,3 +524,110 @@ class GoogleCalendarToolMetadataService: if row: return row[0], row[1] return None + + async def _resolve_live_event( + self, search_space_id: int, user_id: str, event_ref: str + ) -> tuple[SearchSourceConnector, dict] | None: + result = await self._db_session.execute( + select(SearchSourceConnector) + .filter( + and_( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES), + ) + ) + .order_by(SearchSourceConnector.last_indexed_at.desc()) + ) + connectors = result.scalars().all() + + for connector in connectors: + try: + events = await self._search_live_events(connector, event_ref) + except Exception: + logger.warning( + "Failed to search live calendar events for connector %s", + connector.id, + exc_info=True, + ) + continue + + if not events: + continue + + normalized_ref = event_ref.strip().lower() + exact_match = next( + ( + event + for event in events + if event.get("summary", "").strip().lower() == normalized_ref + ), + None, + ) + return connector, exact_match or events[0] + + return None + + async def _search_live_events( + self, connector: SearchSourceConnector, event_ref: str + ) -> list[dict]: + if self._is_composio_connector(connector): + data, error = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_EVENTS_LIST", + { + "calendar_id": "primary", + "q": event_ref, + "max_results": 10, + "single_events": True, + "order_by": "startTime", + }, + ) + if error: + raise RuntimeError(error) + if isinstance(data, dict): + return data.get("items") or data.get("events") or [] + return data if isinstance(data, list) else [] + + creds = await self._build_credentials(connector) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + response = await loop.run_in_executor( + None, + lambda: ( + service.events() + .list( + calendarId="primary", + q=event_ref, + maxResults=10, + singleEvents=True, + orderBy="startTime", + ) + .execute() + ), + ) + return response.get("items", []) + + def _event_dict_from_live_event(self, event: dict) -> dict: + start_data = event.get("start", {}) + end_data = event.get("end", {}) + return { + "event_id": event.get("id", ""), + "summary": event.get("summary", "No Title"), + "start": start_data.get("dateTime", start_data.get("date", "")), + "end": end_data.get("dateTime", end_data.get("date", "")), + "description": event.get("description", ""), + "location": event.get("location", ""), + "attendees": [ + { + "email": attendee.get("email", ""), + "responseStatus": attendee.get("responseStatus", ""), + } + for attendee in event.get("attendees", []) + ], + "calendar_id": event.get("calendarId", "primary"), + "document_id": None, + "indexed_at": None, + } diff --git a/surfsense_backend/app/services/google_drive/tool_metadata_service.py b/surfsense_backend/app/services/google_drive/tool_metadata_service.py index 221bee14a..0f654bc78 100644 --- a/surfsense_backend/app/services/google_drive/tool_metadata_service.py +++ b/surfsense_backend/app/services/google_drive/tool_metadata_service.py @@ -13,7 +13,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) -from app.utils.google_credentials import build_composio_credentials +from app.services.composio_service import ComposioService logger = 
logging.getLogger(__name__) @@ -67,6 +67,42 @@ class GoogleDriveToolMetadataService: def __init__(self, db_session: AsyncSession): self._db_session = db_session + def _is_composio_connector(self, connector: SearchSourceConnector) -> bool: + return ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR + ) + + def _get_composio_connected_account_id( + self, connector: SearchSourceConnector + ) -> str: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected_account_id not found") + return cca_id + + async def _execute_composio_drive_tool( + self, + connector: SearchSourceConnector, + tool_name: str, + params: dict, + ) -> tuple[dict | list | None, str | None]: + result = await ComposioService().execute_tool( + connected_account_id=self._get_composio_connected_account_id(connector), + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{connector.user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Drive error") + data = result.get("data") + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner), None + return inner, None + return data, None + async def get_creation_context(self, search_space_id: int, user_id: str) -> dict: accounts = await self._get_google_drive_accounts(search_space_id, user_id) @@ -200,19 +236,21 @@ class GoogleDriveToolMetadataService: if not connector: return True - pre_built_creds = None - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) + if self._is_composio_connector(connector): + _data, error = await self._execute_composio_drive_tool( + connector, + "GOOGLEDRIVE_LIST_FILES", + { + "q": "trashed = false", + "page_size": 1, + "fields": "files(id)", + }, + ) + return bool(error) client = GoogleDriveClient( session=self._db_session, connector_id=connector_id, - credentials=pre_built_creds, ) await client.list_files( query="trashed = false", page_size=1, fields="files(id)" @@ -274,19 +312,39 @@ class GoogleDriveToolMetadataService: parent_folders[connector_id] = [] continue - pre_built_creds = None - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) + if self._is_composio_connector(connector): + data, error = await self._execute_composio_drive_tool( + connector, + "GOOGLEDRIVE_LIST_FILES", + { + "q": "mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents", + "fields": "files(id,name)", + "page_size": 50, + }, + ) + if error: + logger.warning( + "Failed to list folders for connector %s: %s", + connector_id, + error, + ) + parent_folders[connector_id] = [] + continue + folders = [] + if isinstance(data, dict): + folders = data.get("files", []) + elif isinstance(data, list): + folders = data + parent_folders[connector_id] = [ + {"folder_id": f["id"], "name": f["name"]} + for f in folders + if f.get("id") and f.get("name") + ] + continue client = GoogleDriveClient( session=self._db_session, connector_id=connector_id, - credentials=pre_built_creds, ) folders, _, error = await client.list_files( diff --git 
a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index f45a6ab63..b4de2a0bf 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -20,6 +20,8 @@ from typing import Any from litellm import Router from litellm.utils import ImageResponse +from app.services.provider_api_base import resolve_api_base + logger = logging.getLogger(__name__) # Special ID for Auto mode - uses router for load balancing @@ -152,12 +154,12 @@ class ImageGenRouterService: return None # Build model string + provider = config.get("provider", "").upper() if config.get("custom_provider"): - model_string = f"{config['custom_provider']}/{config['model_name']}" + provider_prefix = config["custom_provider"] else: - provider = config.get("provider", "").upper() provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" + model_string = f"{provider_prefix}/{config['model_name']}" # Build litellm params litellm_params: dict[str, Any] = { @@ -165,9 +167,16 @@ class ImageGenRouterService: "api_key": config.get("api_key"), } - # Add optional api_base - if config.get("api_base"): - litellm_params["api_base"] = config["api_base"] + # Resolve ``api_base`` so deployments don't silently inherit + # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against + # the wrong provider (see ``provider_api_base`` docstring). + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) + if api_base: + litellm_params["api_base"] = api_base # Add api_version (required for Azure) if config.get("api_version"): diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 4bce79a43..d220aa346 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -28,6 +28,7 @@ from litellm.exceptions import ( BadRequestError as LiteLLMBadRequestError, ContextWindowExceededError, ) +from pydantic import Field from app.utils.perf import get_perf_logger @@ -133,42 +134,14 @@ PROVIDER_MAP = { } -# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when -# a global LLM config does *not* specify ``api_base``: without this, LiteLLM -# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``, -# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku`` -# request to an Azure endpoint, which then 404s with ``Resource not found``. -# Only providers with a well-known, stable public base URL are listed here — -# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, -# huggingface, databricks, cloudflare, replicate) are intentionally omitted -# so their existing config-driven behaviour is preserved. 
-PROVIDER_DEFAULT_API_BASE = { - "openrouter": "https://openrouter.ai/api/v1", - "groq": "https://api.groq.com/openai/v1", - "mistral": "https://api.mistral.ai/v1", - "perplexity": "https://api.perplexity.ai", - "xai": "https://api.x.ai/v1", - "cerebras": "https://api.cerebras.ai/v1", - "deepinfra": "https://api.deepinfra.com/v1/openai", - "fireworks_ai": "https://api.fireworks.ai/inference/v1", - "together_ai": "https://api.together.xyz/v1", - "anyscale": "https://api.endpoints.anyscale.com/v1", - "cometapi": "https://api.cometapi.com/v1", - "sambanova": "https://api.sambanova.ai/v1", -} - - -# Canonical provider → base URL when a config uses a generic ``openai``-style -# prefix but the ``provider`` field tells us which API it really is -# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but -# each has its own base URL). -PROVIDER_KEY_DEFAULT_API_BASE = { - "DEEPSEEK": "https://api.deepseek.com/v1", - "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", - "MOONSHOT": "https://api.moonshot.ai/v1", - "ZHIPU": "https://open.bigmodel.cn/api/paas/v4", - "MINIMAX": "https://api.minimax.io/v1", -} +# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were +# hoisted to ``app.services.provider_api_base`` so vision and image-gen +# call sites can share the exact same defense (OpenRouter / Groq / etc. +# 404-ing against an inherited Azure endpoint). Re-exported here for +# backward compatibility with any external import. +from app.services.provider_api_base import ( # noqa: E402 + resolve_api_base, +) class LLMRouterService: @@ -207,6 +180,12 @@ class LLMRouterService: """ Initialize the router with global LLM configurations. + Configs with ``router_pool_eligible=False`` are skipped so that + dynamic OpenRouter entries stay out of the shared router pool used + by title-gen / sub-agent ``model="auto"`` flows. Those dynamic + entries are still available for user-facing Auto-mode thread pinning + via ``auto_model_pin_service``. + Args: global_configs: List of global LLM config dictionaries from YAML router_settings: Optional router settings (routing_strategy, num_retries, etc.) @@ -220,6 +199,8 @@ class LLMRouterService: model_list = [] premium_models: set[str] = set() for config in global_configs: + if config.get("router_pool_eligible") is False: + continue deployment = cls._config_to_deployment(config) if deployment: model_list.append(deployment) @@ -308,10 +289,45 @@ class LLMRouterService: logger.error(f"Failed to initialize LLM Router: {e}") instance._router = None + @classmethod + def rebuild( + cls, + global_configs: list[dict], + router_settings: dict | None = None, + ) -> None: + """Reset the router and re-run ``initialize`` with fresh configs. + + ``initialize`` short-circuits once it has run to avoid re-creating the + LiteLLM Router on every request; ``rebuild`` deliberately clears + ``_initialized`` so a caller (e.g. background OpenRouter refresh) + can force the pool to be rebuilt after catalogue changes. + """ + instance = cls.get_instance() + instance._initialized = False + instance._router = None + instance._model_list = [] + instance._premium_model_strings = set() + cls.initialize(global_configs, router_settings) + @classmethod def is_premium_model(cls, model_string: str) -> bool: - """Return True if *model_string* (as reported by LiteLLM) belongs to a - premium-tier deployment in the router pool.""" + """Return True if *model_string* belongs to a premium-tier deployment + in the LiteLLM router pool. 
+ + Scope: only covers configs with ``router_pool_eligible`` truthy. That + includes static YAML premium configs AND dynamic OpenRouter *premium* + entries (which opt in at generation time). Dynamic OpenRouter *free* + entries are deliberately kept out of the router pool — OpenRouter + enforces free-tier limits globally per account, so per-deployment + router accounting can't represent them correctly — and therefore + return ``False`` here, which matches their ``billing_tier="free"`` + (no premium quota). + + For per-request premium checks on an arbitrary config (static or + dynamic, pool or non-pool), read ``agent_config.is_premium`` instead; + that reflects the per-config ``billing_tier`` directly and is what + user-facing Auto-mode thread pinning uses to bill correctly. + """ instance = cls.get_instance() return model_string in instance._premium_model_strings @@ -422,14 +438,14 @@ class LLMRouterService: # Resolve ``api_base``. Config value wins; otherwise apply a # provider-aware default so the deployment does not silently # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route - # requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE`` + # requests to the wrong endpoint. See ``provider_api_base`` # docstring for the motivating bug (OpenRouter models 404-ing # against an Azure endpoint). - api_base = config.get("api_base") - if not api_base: - api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider) - if not api_base: - api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix) + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) if api_base: litellm_params["api_base"] = api_base @@ -573,6 +589,11 @@ class ChatLiteLLMRouter(BaseChatModel): # Public attributes that Pydantic will manage model: str = "auto" streaming: bool = True + # Static kwargs that flow through to ``litellm.completion(...)`` on every + # invocation (e.g. ``cache_control_injection_points`` set by + # ``apply_litellm_prompt_caching``). Per-call ``**kwargs`` from + # ``invoke()`` still take precedence — see ``_generate``/``_astream``. + model_kwargs: dict[str, Any] = Field(default_factory=dict) # Bound tools and tool choice for tool calling _bound_tools: list[dict] | None = None @@ -898,13 +919,16 @@ class ChatLiteLLMRouter(BaseChatModel): logger.warning(f"Failed to convert tool {tool}: {e}") continue - # Create a new instance with tools bound + # Create a new instance with tools bound. Carry through ``model_kwargs`` + # so static settings (e.g. cache_control_injection_points) survive the + # bind_tools rebuild. return ChatLiteLLMRouter( router=self._router, bound_tools=formatted_tools if formatted_tools else None, tool_choice=tool_choice, model=self.model, streaming=self.streaming, + model_kwargs=dict(self.model_kwargs), **kwargs, ) @@ -929,8 +953,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. 
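The merge order is the whole contract here: static ``model_kwargs`` seed every completion call, and per-invocation kwargs win on key conflicts. A minimal illustration with made-up values (only the ``{**static, **per_call}`` ordering mirrors the ``call_kwargs`` line below):

# Hypothetical values; only the merge ordering matches the real code.
static_kwargs = {"temperature": 0.2, "seed": 7}
per_call_kwargs = {"temperature": 0.9}
merged = {**static_kwargs, **per_call_kwargs}
assert merged == {"seed": 7, "temperature": 0.9}  # per-call wins, static survives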
+ call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: @@ -997,8 +1023,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. + call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: @@ -1060,8 +1088,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. + call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: @@ -1110,8 +1140,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. + call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 942a9b7af..ade202c72 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -16,6 +16,7 @@ from app.services.llm_router_service import ( get_auto_mode_llm, is_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import token_tracker # Configure litellm to automatically drop unsupported parameters @@ -496,8 +497,14 @@ async def get_vision_llm( - Auto mode (ID 0): VisionLLMRouterService - Global (negative ID): YAML configs - DB (positive ID): VisionLLMConfig table + + Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM` + so each ``ainvoke`` debits the search-space owner's premium credit + pool. User-owned BYOK configs and free global configs are returned + unwrapped — they don't consume premium credit (issue M). """ from app.db import VisionLLMConfig + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM from app.services.vision_llm_router_service import ( VISION_PROVIDER_MAP, VisionLLMRouterService, @@ -519,6 +526,8 @@ async def get_vision_llm( logger.error(f"No vision LLM configured for search space {search_space_id}") return None + owner_user_id = search_space.user_id + if is_vision_auto_mode(config_id): if not VisionLLMRouterService.is_initialized(): logger.error( @@ -526,6 +535,13 @@ async def get_vision_llm( ) return None try: + # Auto mode is currently treated as free at the wrapper + # level — the underlying router can dispatch to either + # premium or free YAML configs but routing decisions are + # opaque. 
If/when we want to bill Auto-routed vision + # calls we'd need to thread the resolved deployment's + # billing_tier back from the router. For now we keep + # parity with chat Auto, which also doesn't pre-classify. return ChatLiteLLMRouter( router=VisionLLMRouterService.get_router(), streaming=True, @@ -541,29 +557,46 @@ async def get_vision_llm( return None if global_cfg.get("custom_provider"): - model_string = ( - f"{global_cfg['custom_provider']}/{global_cfg['model_name']}" - ) + provider_prefix = global_cfg["custom_provider"] + model_string = f"{provider_prefix}/{global_cfg['model_name']}" else: - prefix = VISION_PROVIDER_MAP.get( + provider_prefix = VISION_PROVIDER_MAP.get( global_cfg["provider"].upper(), global_cfg["provider"].lower(), ) - model_string = f"{prefix}/{global_cfg['model_name']}" + model_string = f"{provider_prefix}/{global_cfg['model_name']}" litellm_kwargs = { "model": model_string, "api_key": global_cfg["api_key"], } - if global_cfg.get("api_base"): - litellm_kwargs["api_base"] = global_cfg["api_base"] + api_base = resolve_api_base( + provider=global_cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=global_cfg.get("api_base"), + ) + if api_base: + litellm_kwargs["api_base"] = api_base if global_cfg.get("litellm_params"): litellm_kwargs.update(global_cfg["litellm_params"]) from app.agents.new_chat.llm_config import SanitizedChatLiteLLM - return SanitizedChatLiteLLM(**litellm_kwargs) + inner_llm = SanitizedChatLiteLLM(**litellm_kwargs) + billing_tier = str(global_cfg.get("billing_tier", "free")).lower() + if billing_tier == "premium": + return QuotaCheckedVisionLLM( + inner_llm, + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=model_string, + quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"), + ) + return inner_llm + + # User-owned (positive ID) BYOK configs — always free. 
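For orientation, the three config-ID namespaces this function dispatches on (per the docstring above: 0 is Auto, negative IDs are global YAML entries, positive IDs are user BYOK rows), as a self-contained sketch; the helper name is ours, not the codebase's:

def classify_vision_config_id(config_id: int) -> str:
    # Hypothetical helper mirroring get_vision_llm's branching order.
    if config_id == 0:
        return "auto"      # VisionLLMRouterService (router pool)
    if config_id < 0:
        return "global"    # YAML config; premium tiers get quota-wrapped
    return "byok"          # user-owned VisionLLMConfig row; always free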
result = await session.execute( select(VisionLLMConfig).where( VisionLLMConfig.id == config_id, @@ -578,20 +611,26 @@ async def get_vision_llm( return None if vision_cfg.custom_provider: - model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}" + provider_prefix = vision_cfg.custom_provider + model_string = f"{provider_prefix}/{vision_cfg.model_name}" else: - prefix = VISION_PROVIDER_MAP.get( + provider_prefix = VISION_PROVIDER_MAP.get( vision_cfg.provider.value.upper(), vision_cfg.provider.value.lower(), ) - model_string = f"{prefix}/{vision_cfg.model_name}" + model_string = f"{provider_prefix}/{vision_cfg.model_name}" litellm_kwargs = { "model": model_string, "api_key": vision_cfg.api_key, } - if vision_cfg.api_base: - litellm_kwargs["api_base"] = vision_cfg.api_base + api_base = resolve_api_base( + provider=vision_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=vision_cfg.api_base, + ) + if api_base: + litellm_kwargs["api_base"] = api_base if vision_cfg.litellm_params: litellm_kwargs.update(vision_cfg.litellm_params) diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 3531d37af..55129668c 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -565,20 +565,31 @@ class VercelStreamingService: # Error Part # ========================================================================= - def format_error(self, error_text: str) -> str: + def format_error( + self, + error_text: str, + error_code: str | None = None, + extra: dict[str, object] | None = None, + ) -> str: """ Format an error message. Args: error_text: The error message text + error_code: Optional machine-readable error code for frontend branching Returns: str: SSE formatted error part Example output: - data: {"type":"error","errorText":"Something went wrong"} + data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"} """ - return self._format_sse({"type": "error", "errorText": error_text}) + payload: dict[str, object] = {"type": "error", "errorText": error_text} + if error_code: + payload["errorCode"] = error_code + if extra: + payload.update(extra) + return self._format_sse(payload) # ========================================================================= # Tool Parts diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 1245f73aa..6454e2d58 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -11,20 +11,81 @@ this service only manages the catalogue, not the inference path. """ import asyncio +import hashlib import logging import threading +import time from typing import Any import httpx +from app.services.quality_score import ( + _HEALTH_BLEND_WEIGHT, + _HEALTH_ENRICH_CONCURRENCY, + _HEALTH_ENRICH_TOP_N_FREE, + _HEALTH_ENRICH_TOP_N_PREMIUM, + _HEALTH_FAIL_RATIO_FALLBACK, + _HEALTH_FETCH_TIMEOUT_SEC, + aggregate_health, + static_score_or, +) + logger = logging.getLogger(__name__) OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" +OPENROUTER_ENDPOINTS_URL_TEMPLATE = ( + "https://openrouter.ai/api/v1/models/{model_id}/endpoints" +) # Sentinel value stored on each generated config so we can distinguish # dynamic OpenRouter entries from hand-written YAML entries during refresh. 
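The refresh path further down swaps dynamic entries atomically by filtering on the marker defined just below. The idiom in isolation (function name is ours; the config dicts are shaped like this file's generators produce):

def swap_dynamic_entries(live: list[dict], fresh: list[dict], marker: str) -> list[dict]:
    # Keep hand-written YAML entries, replace everything previously generated.
    static_only = [cfg for cfg in live if not cfg.get(marker)]
    return static_only + fresh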
_OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__" +# Width of the hash space used by ``_stable_config_id``. 9_000_000 provides +# enough headroom to avoid frequent collisions for OpenRouter's catalogue +# (~300 models) while keeping IDs comfortably within Postgres INTEGER range. +_STABLE_ID_HASH_WIDTH = 9_000_000 + + +def _stable_config_id(model_id: str, offset: int, taken: set[int]) -> int: + """Derive a deterministic negative config ID from ``model_id``. + + The same ``model_id`` always hashes to the same base value so thread pins + survive catalogue churn (models appearing/disappearing/reordering between + refreshes). On collision we decrement until we find an unused slot; this + keeps the mapping stable for the first config that claimed a slot and + only shifts collisions, which is much less disruptive than the legacy + index-based scheme that reshuffled every ID when the catalogue changed. + """ + digest = hashlib.blake2b(model_id.encode("utf-8"), digest_size=6).digest() + base = offset - (int.from_bytes(digest, "big") % _STABLE_ID_HASH_WIDTH) + cid = base + while cid in taken: + cid -= 1 + taken.add(cid) + return cid + + +def _openrouter_tier(model: dict) -> str: + """Classify an OpenRouter model as ``"free"`` or ``"premium"``. + + Per OpenRouter's API contract, a model is free if: + - Its id ends with ``:free`` (OpenRouter's own free-variant convention), or + - Both ``pricing.prompt`` and ``pricing.completion`` are zero strings. + + Anything else (missing pricing, non-zero pricing) falls through to + ``"premium"`` so we never under-charge users. This derivation runs off the + already-cached /api/v1/models payload, so it adds no network cost. + """ + if model.get("id", "").endswith(":free"): + return "free" + pricing = model.get("pricing") or {} + prompt = str(pricing.get("prompt", "")).strip() + completion = str(pricing.get("completion", "")).strip() + if prompt == "0" and completion == "0": + return "free" + return "premium" + def _is_text_output_model(model: dict) -> bool: """Return True if the model produces text output only (skip image/audio generators).""" @@ -32,6 +93,53 @@ def _is_text_output_model(model: dict) -> bool: return output_mods == ["text"] +def _is_image_output_model(model: dict) -> bool: + """Return True if the model can produce image output. + + OpenRouter's ``architecture.output_modalities`` is a list (e.g. + ``["image"]`` for pure image generators, ``["text", "image"]`` for + multi-modal generators that also emit captions). We accept any model + that can output images; the call site decides whether to use the + image-generation API or chat completion. + """ + output_mods = model.get("architecture", {}).get("output_modalities", []) or [] + return "image" in output_mods + + +def _is_vision_input_model(model: dict) -> bool: + """Return True if the model can ingest an image AND emit text. + + OpenRouter's ``architecture.input_modalities`` lists what the model + accepts; ``output_modalities`` lists what it produces. A vision LLM + is a model that takes images in and produces text out — i.e. it can + answer questions about a screenshot or extract content from an + image. Pure image-to-image models (e.g. style transfer) and + text-only models are excluded. 
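    A spot-check against an illustrative payload (shape assumed from the
    fields this module reads, not a verbatim OpenRouter response)::

        >>> _is_vision_input_model(
        ...     {"architecture": {"input_modalities": ["text", "image"],
        ...                       "output_modalities": ["text"]}}
        ... )
        True
        >>> _is_vision_input_model({"architecture": {"output_modalities": ["image"]}})
        False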
+ """ + arch = model.get("architecture", {}) or {} + input_mods = arch.get("input_modalities", []) or [] + output_mods = arch.get("output_modalities", []) or [] + return "image" in input_mods and "text" in output_mods + + +def _supports_image_input(model: dict) -> bool: + """Return True if the model accepts ``image`` in its input modalities. + + Differs from :func:`_is_vision_input_model` in that it does NOT + require text output — chat-tab models always emit text already (the + chat catalog filters by ``_is_text_output_model``), so the only + extra capability we need to track per chat config is whether the + model can ingest user-attached images. The chat selector and the + streaming task both key off this flag to prevent hitting an + OpenRouter 404 ``"No endpoints found that support image input"`` + when the user uploads an image and selects a text-only model + (DeepSeek V3, Llama 3.x base, etc.). + """ + arch = model.get("architecture", {}) or {} + input_mods = arch.get("input_modalities", []) or [] + return "image" in input_mods + + def _supports_tool_calling(model: dict) -> bool: """Return True if the model supports function/tool calling.""" supported = model.get("supported_parameters") or [] @@ -56,6 +164,11 @@ _EXCLUDED_MODEL_IDS: set[str] = { # Deep-research models reject standard params (temperature, etc.) "openai/o3-deep-research", "openai/o4-mini-deep-research", + # OpenRouter's own meta-router over free models. We already enumerate every + # concrete ``:free`` model into GLOBAL_LLM_CONFIGS and Auto-mode thread + # pinning handles churn via the repair path, so exposing an additional + # indirection layer would only duplicate the capability with an opaque slug. + "openrouter/free", } _EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",) @@ -109,24 +222,71 @@ async def _fetch_models_async() -> list[dict] | None: return None +def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]: + """Return a ``{model_id: {"prompt": str, "completion": str}}`` map. + + Pricing values are kept as the raw OpenRouter strings (e.g. + ``"0.000003"``); ``pricing_registration`` converts them to floats + when registering with LiteLLM. Models with missing or malformed + pricing are simply omitted — operator-side risk if any of those are + premium. + """ + pricing: dict[str, dict[str, str]] = {} + for model in raw_models: + model_id = str(model.get("id") or "").strip() + if not model_id: + continue + p = model.get("pricing") or {} + prompt = p.get("prompt") + completion = p.get("completion") + if prompt is None and completion is None: + continue + pricing[model_id] = { + "prompt": str(prompt) if prompt is not None else "", + "completion": str(completion) if completion is not None else "", + } + return pricing + + def _generate_configs( raw_models: list[dict], settings: dict[str, Any], ) -> list[dict]: - """ - Convert raw OpenRouter model entries into global LLM config dicts. + """Convert raw OpenRouter model entries into global LLM config dicts. - Models are sorted by ID for deterministic, stable ID assignment across - restarts and refreshes. + Tier (``billing_tier``) is derived per-model from OpenRouter's own API + signals via ``_openrouter_tier`` — there is no longer a uniform YAML + override. Config IDs are derived via ``_stable_config_id`` so they + survive catalogue churn across refreshes. 
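    A determinism spot-check for the ID scheme (``-10000`` is the default
    chat ``id_offset``; fresh ``taken`` sets on both sides, so only the
    equality is asserted, not any concrete ID)::

        >>> a = _stable_config_id("openai/gpt-4o", -10000, set())
        >>> b = _stable_config_id("openai/gpt-4o", -10000, set())
        >>> a == b
        True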
+ + Router-pool membership is tier-aware: + + - Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``) + so sub-agent ``model="auto"`` flows benefit from load balancing and + failover across the curated YAML configs and the OR premium passthrough. + - Free OR models stay excluded (``router_pool_eligible=False``). LiteLLM + Router tracks rate limits per deployment, but OpenRouter enforces a + single global free-tier quota (~20 RPM + 50-1000 daily requests + account-wide across every ``:free`` model), so rotating across many + free deployments would only burn the shared bucket faster. Free OR + models remain fully available for user-facing Auto-mode thread pinning + via ``auto_model_pin_service``. + + OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream + via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer + because our own Auto (Fastest) pin + 24 h refresh + repair logic already + cover the catalogue-churn case. """ id_offset: int = settings.get("id_offset", -10000) api_key: str = settings.get("api_key", "") - billing_tier: str = settings.get("billing_tier", "premium") - anonymous_enabled: bool = settings.get("anonymous_enabled", False) seo_enabled: bool = settings.get("seo_enabled", False) quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) rpm: int = settings.get("rpm", 200) - tpm: int = settings.get("tpm", 1000000) + tpm: int = settings.get("tpm", 1_000_000) + free_rpm: int = settings.get("free_rpm", 20) + free_tpm: int = settings.get("free_tpm", 100_000) + anon_paid: bool = settings.get("anonymous_enabled_paid", False) + anon_free: bool = settings.get("anonymous_enabled_free", False) litellm_params: dict = settings.get("litellm_params") or {} system_instructions: str = settings.get("system_instructions", "") use_default: bool = settings.get("use_default_system_instructions", True) @@ -142,19 +302,24 @@ def _generate_configs( and _is_allowed_model(m) and "/" in m.get("id", "") ] - text_models.sort(key=lambda m: m["id"]) configs: list[dict] = [] - for idx, model in enumerate(text_models): + taken: set[int] = set() + now_ts = int(time.time()) + + for model in text_models: model_id: str = model["id"] name: str = model.get("name", model_id) + tier = _openrouter_tier(model) + + static_q = static_score_or(model, now_ts=now_ts) cfg: dict[str, Any] = { - "id": id_offset - idx, + "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter", - "billing_tier": billing_tier, - "anonymous_enabled": anonymous_enabled, + "billing_tier": tier, + "anonymous_enabled": anon_free if tier == "free" else anon_paid, "seo_enabled": seo_enabled, "seo_slug": None, "quota_reserve_tokens": quota_reserve_tokens, @@ -162,12 +327,199 @@ def _generate_configs( "model_name": model_id, "api_key": api_key, "api_base": "", - "rpm": rpm, - "tpm": tpm, + "rpm": free_rpm if tier == "free" else rpm, + "tpm": free_tpm if tier == "free" else tpm, "litellm_params": dict(litellm_params), "system_instructions": system_instructions, "use_default_system_instructions": use_default, "citations_enabled": citations_enabled, + # Premium OR deployments join the LiteLLM router pool so sub-agent + # model="auto" flows can load-balance / fail over across them. + # Free OR deployments stay out: OpenRouter's free tier is a single + # account-wide quota, so per-deployment routing can't spread load + # there — it just drains the shared bucket faster. 
+ "router_pool_eligible": tier == "premium", + # Capability flag derived from ``architecture.input_modalities``. + # Read by the new-chat selector to dim image-incompatible models + # when the user has pending image attachments, and by + # ``stream_new_chat`` as a fail-fast safety net before the + # OpenRouter request would otherwise 404 with + # ``"No endpoints found that support image input"``. + "supports_image_input": _supports_image_input(model), + _OPENROUTER_DYNAMIC_MARKER: True, + # Auto (Fastest) ranking metadata. ``quality_score`` is initialised + # to the static score and gets re-blended with health on the next + # ``_enrich_health`` pass (synchronous on refresh, deferred on cold + # start so startup latency is unchanged). + "auto_pin_tier": "B" if tier == "premium" else "C", + "quality_score_static": static_q, + "quality_score_health": None, + "quality_score": static_q, + "health_gated": False, + } + configs.append(cfg) + + return configs + + +# ID-offset bands used to keep dynamic OpenRouter configs in their own +# namespace per surface. Image / vision get separate bands so a single +# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to. +_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000 +_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000 + + +def _generate_image_gen_configs( + raw_models: list[dict], settings: dict[str, Any] +) -> list[dict]: + """Convert OpenRouter image-generation models into global image-gen + config dicts (matches the YAML shape consumed by ``image_generation_routes``). + + Filter: + - architecture.output_modalities contains "image" + - compatible provider (excluded slugs blocked) + - allowed model id (excluded list blocked) + + Notably we *drop* the chat-only filters (``_supports_tool_calling`` and + ``_has_sufficient_context``) because tool calls and context windows are + irrelevant for the ``aimage_generation`` API. ``billing_tier`` is + derived per model the same way as chat (``_openrouter_tier``). + + Cost is intentionally *not* registered with LiteLLM at startup + (``pricing_registration`` skips image gen): OpenRouter image-gen + models are not in LiteLLM's native cost map and OpenRouter populates + ``response_cost`` directly from the response header. A defensive + branch in ``_extract_cost_usd`` handles the rare case where + ``usage.cost`` is missing — see ``token_tracking_service``. + """ + id_offset: int = int( + settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT + ) + api_key: str = settings.get("api_key", "") + rpm: int = settings.get("rpm", 200) + free_rpm: int = settings.get("free_rpm", 20) + litellm_params: dict = settings.get("litellm_params") or {} + + image_models = [ + m + for m in raw_models + if _is_image_output_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] + + configs: list[dict] = [] + taken: set[int] = set() + for model in image_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) + tier = _openrouter_tier(model) + + cfg: dict[str, Any] = { + "id": _stable_config_id(model_id, id_offset, taken), + "name": name, + "description": f"{name} via OpenRouter (image generation)", + "provider": "OPENROUTER", + "model_name": model_id, + "api_key": api_key, + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` and 404 on + # ``image_generation/transformation`` (defense-in-depth, see + # ``provider_api_base`` docstring). 
+ "api_base": "https://openrouter.ai/api/v1", + "api_version": None, + "rpm": free_rpm if tier == "free" else rpm, + "litellm_params": dict(litellm_params), + "billing_tier": tier, + _OPENROUTER_DYNAMIC_MARKER: True, + } + configs.append(cfg) + + return configs + + +def _generate_vision_llm_configs( + raw_models: list[dict], settings: dict[str, Any] +) -> list[dict]: + """Convert OpenRouter vision-capable LLMs into global vision-LLM config + dicts (matches the YAML shape consumed by ``vision_llm_routes``). + + Filter: + - architecture.input_modalities contains "image" + - architecture.output_modalities contains "text" + - compatible provider (excluded slugs blocked) + - allowed model id (excluded list blocked) + + Vision-LLM is invoked from the indexer (image extraction during + document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so + the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context`` + filters do not apply: a small-context vision model that doesn't + advertise tool-calling is still perfectly viable for "describe this + image" prompts. + """ + id_offset: int = int( + settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT + ) + api_key: str = settings.get("api_key", "") + rpm: int = settings.get("rpm", 200) + tpm: int = settings.get("tpm", 1_000_000) + free_rpm: int = settings.get("free_rpm", 20) + free_tpm: int = settings.get("free_tpm", 100_000) + quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) + litellm_params: dict = settings.get("litellm_params") or {} + + vision_models = [ + m + for m in raw_models + if _is_vision_input_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] + + configs: list[dict] = [] + taken: set[int] = set() + for model in vision_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) + tier = _openrouter_tier(model) + pricing = model.get("pricing") or {} + + # Capture per-token prices so ``pricing_registration`` can + # register them with LiteLLM at startup (and so the cost + # estimator in ``estimate_call_reserve_micros`` can resolve + # them at reserve time). + try: + input_cost = float(pricing.get("prompt", 0) or 0) + except (TypeError, ValueError): + input_cost = 0.0 + try: + output_cost = float(pricing.get("completion", 0) or 0) + except (TypeError, ValueError): + output_cost = 0.0 + + cfg: dict[str, Any] = { + "id": _stable_config_id(model_id, id_offset, taken), + "name": name, + "description": f"{name} via OpenRouter (vision)", + "provider": "OPENROUTER", + "model_name": model_id, + "api_key": api_key, + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see + # ``provider_api_base`` docstring). + "api_base": "https://openrouter.ai/api/v1", + "api_version": None, + "rpm": free_rpm if tier == "free" else rpm, + "tpm": free_tpm if tier == "free" else tpm, + "litellm_params": dict(litellm_params), + "billing_tier": tier, + "quota_reserve_tokens": quota_reserve_tokens, + "input_cost_per_token": input_cost or None, + "output_cost_per_token": output_cost or None, _OPENROUTER_DYNAMIC_MARKER: True, } configs.append(cfg) @@ -187,6 +539,25 @@ class OpenRouterIntegrationService: self._configs_by_id: dict[int, dict] = {} self._initialized = False self._refresh_task: asyncio.Task | None = None + # Last-good per-model health snapshot. 
Survives across refresh + # cycles so a transient OpenRouter /endpoints outage doesn't drop + # every cfg back to static-only scoring. + # Shape: {model_name: {"gated": bool, "score": float | None}} + self._health_cache: dict[str, dict[str, Any]] = {} + self._enrich_task: asyncio.Task | None = None + # Raw OpenRouter pricing per model_id, captured at the same time + # we generate configs. Consumed by ``pricing_registration`` to + # teach LiteLLM the per-token cost of every dynamic deployment so + # the success-callback can populate ``response_cost`` correctly. + self._raw_pricing: dict[str, dict[str, str]] = {} + # Cached raw catalogue from the most recent fetch. Image / vision + # emitters reuse this to avoid a second network call per surface. + self._raw_models: list[dict] = [] + # Image / vision config caches (only populated when the matching + # opt-in flag is true on initialize). Refreshed in lockstep with + # the chat catalogue. + self._image_configs: list[dict] = [] + self._vision_configs: list[dict] = [] @classmethod def get_instance(cls) -> "OpenRouterIntegrationService": @@ -216,16 +587,55 @@ class OpenRouterIntegrationService: self._initialized = True return [] + self._raw_models = raw_models self._configs = _generate_configs(raw_models, settings) self._configs_by_id = {c["id"]: c for c in self._configs} + self._raw_pricing = _extract_raw_pricing(raw_models) + + # Populate image / vision caches when their opt-in flag is set. + # Empty otherwise so the accessors return [] without re-running + # filters every refresh. + if settings.get("image_generation_enabled"): + self._image_configs = _generate_image_gen_configs(raw_models, settings) + logger.info( + "OpenRouter integration: image-gen emission ON (%d models)", + len(self._image_configs), + ) + else: + self._image_configs = [] + + if settings.get("vision_enabled"): + self._vision_configs = _generate_vision_llm_configs(raw_models, settings) + logger.info( + "OpenRouter integration: vision LLM emission ON (%d models)", + len(self._vision_configs), + ) + else: + self._vision_configs = [] + self._initialized = True + tier_counts = self._tier_counts(self._configs) logger.info( - "OpenRouter integration: loaded %d models (IDs %d to %d)", + "OpenRouter integration: loaded %d models (free=%d, premium=%d)", len(self._configs), - self._configs[0]["id"] if self._configs else 0, - self._configs[-1]["id"] if self._configs else 0, + tier_counts["free"], + tier_counts["premium"], ) + + # Schedule the first health-enrichment pass as a deferred task so + # cold-start latency is unchanged. Only valid when an event loop is + # already running (e.g. FastAPI lifespan); Celery worker init is + # fully sync so we silently skip — its first refresh tick (or the + # next refresh from the web process) will populate health data. 
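The schedule-if-running pattern used below, in isolation (helper name is ours): in a sync context ``asyncio.get_running_loop()`` raises ``RuntimeError`` and the work is simply skipped; under FastAPI's lifespan a loop exists and the coroutine is scheduled without blocking startup.

import asyncio
from collections.abc import Coroutine
from typing import Any

def schedule_if_loop_running(coro: Coroutine[Any, Any, None]) -> asyncio.Task | None:
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        coro.close()  # sync context (e.g. Celery worker init): drop silently
        return None
    return loop.create_task(coro)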
+ try: + loop = asyncio.get_running_loop() + self._enrich_task = loop.create_task( + self._enrich_health_safely(self._configs) + ) + except RuntimeError: + pass + return self._configs # ------------------------------------------------------------------ @@ -241,6 +651,8 @@ class OpenRouterIntegrationService: new_configs = _generate_configs(raw_models, self._settings) new_by_id = {c["id"]: c for c in new_configs} + self._raw_pricing = _extract_raw_pricing(raw_models) + self._raw_models = raw_models from app.config import config as app_config @@ -254,7 +666,263 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id - logger.info("OpenRouter refresh: updated to %d models", len(new_configs)) + # Image / vision lists are atomic-swapped the same way: filter out + # the previous dynamic entries from the live config list and append + # the freshly generated ones. No-ops when the opt-in flag is off. + if self._settings.get("image_generation_enabled"): + new_image = _generate_image_gen_configs(raw_models, self._settings) + static_image = [ + c + for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS + if not c.get(_OPENROUTER_DYNAMIC_MARKER) + ] + app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image + self._image_configs = new_image + + if self._settings.get("vision_enabled"): + new_vision = _generate_vision_llm_configs(raw_models, self._settings) + static_vision = [ + c + for c in app_config.GLOBAL_VISION_LLM_CONFIGS + if not c.get(_OPENROUTER_DYNAMIC_MARKER) + ] + app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision + self._vision_configs = new_vision + + # Catalogue churn invalidates per-config "recently healthy" credit + # earned by the previous turn's preflight. Drop the whole table so + # the next turn re-probes against the freshly loaded configs. + try: + from app.services.auto_model_pin_service import clear_healthy + + clear_healthy() + except Exception: + logger.debug( + "OpenRouter refresh: clear_healthy import skipped", exc_info=True + ) + + tier_counts = self._tier_counts(new_configs) + logger.info( + "OpenRouter refresh: updated to %d models (free=%d, premium=%d)", + len(new_configs), + tier_counts["free"], + tier_counts["premium"], + ) + + # Re-blend health scores against the freshly fetched catalogue. Also + # re-stamps health for any YAML-curated cfg with provider==OPENROUTER + # so a hand-picked dead OR model is gated like a dynamic one. + await self._enrich_health_safely(static_configs + new_configs, log_summary=True) + + # Re-register LiteLLM pricing for the freshly fetched catalogue + # so newly added OR models bill correctly on their first call. + # Runs before the router rebuild because the router may issue + # cost-table lookups during deployment registration. + try: + from app.services.pricing_registration import ( + register_pricing_from_global_configs, + ) + + register_pricing_from_global_configs() + except Exception as exc: + logger.warning( + "OpenRouter refresh: pricing re-registration skipped (%s)", exc + ) + + # Rebuild the LiteLLM router so freshly fetched configs flow through + # (dynamic OR premium entries now opt into the pool, free ones stay + # out; a refresh also needs to pick up any static-config edits and + # reset cached context-window profiles). 
+ try: + from app.config import config as _app_config + from app.services.llm_router_service import ( + LLMRouterService, + _router_instance_cache as _chat_router_cache, + ) + + LLMRouterService.rebuild( + _app_config.GLOBAL_LLM_CONFIGS, + getattr(_app_config, "ROUTER_SETTINGS", None), + ) + _chat_router_cache.clear() + except Exception as exc: + logger.warning("OpenRouter refresh: router rebuild skipped (%s)", exc) + + @staticmethod + def _tier_counts(configs: list[dict]) -> dict[str, int]: + counts = {"free": 0, "premium": 0} + for cfg in configs: + tier = str(cfg.get("billing_tier", "")).lower() + if tier in counts: + counts[tier] += 1 + return counts + + # ------------------------------------------------------------------ + # Auto (Fastest) health enrichment + # ------------------------------------------------------------------ + + async def _enrich_health_safely( + self, configs: list[dict], *, log_summary: bool = True + ) -> None: + """Wrapper around ``_enrich_health`` that swallows all errors. + + Health enrichment is best-effort: any failure must leave cfgs in + their static-only state and never break refresh / startup. + """ + try: + await self._enrich_health(configs, log_summary=log_summary) + except Exception: + logger.exception("OpenRouter health enrichment failed") + + async def _enrich_health( + self, configs: list[dict], *, log_summary: bool = True + ) -> None: + """Fetch per-model ``/endpoints`` data for the top OR cfgs and blend + the resulting health score into ``cfg["quality_score"]``. + + Bounded fan-out: top-N per tier by ``quality_score_static`` only, + with ``asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)`` guarding the + outbound HTTP. Misses fall back to a per-model last-good cache; if + the failure ratio crosses ``_HEALTH_FAIL_RATIO_FALLBACK`` we keep + the entire previous cycle's cache for this run. + """ + or_cfgs = [ + c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER" + ] + if not or_cfgs: + return + + premium_pool = sorted( + [c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "premium"], + key=lambda c: -int(c.get("quality_score_static") or 0), + )[:_HEALTH_ENRICH_TOP_N_PREMIUM] + free_pool = sorted( + [c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "free"], + key=lambda c: -int(c.get("quality_score_static") or 0), + )[:_HEALTH_ENRICH_TOP_N_FREE] + # De-duplicate while preserving order: a cfg shouldn't fall in both + # tiers, but defensive code is cheap here. + seen_ids: set[int] = set() + selected: list[dict] = [] + for cfg in premium_pool + free_pool: + cid = int(cfg.get("id", 0)) + if cid in seen_ids: + continue + seen_ids.add(cid) + selected.append(cfg) + + if not selected: + return + + api_key = str(self._settings.get("api_key") or "") + semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY) + + async with httpx.AsyncClient(timeout=_HEALTH_FETCH_TIMEOUT_SEC) as client: + results = await asyncio.gather( + *( + self._fetch_endpoints(client, semaphore, api_key, cfg) + for cfg in selected + ) + ) + + fail_count = sum(1 for _, _, err in results if err is not None) + fail_ratio = fail_count / len(results) if results else 0.0 + degraded = fail_ratio >= _HEALTH_FAIL_RATIO_FALLBACK + if degraded: + logger.warning( + "auto_pin_health_enrich_degraded fail_ratio=%.2f total=%d " + "using_last_good_cache=true", + fail_ratio, + len(results), + ) + + # Per-cfg health update. 
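+        # Illustrative fallback outcome (numbers invented): a cfg whose
+        # /endpoints fetch failed this cycle but whose model has a cached
+        # {"gated": False, "score": 87.0} entry keeps that last-good
+        # health data instead of dropping back to static-only scoring.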
+ for cfg, endpoints, err in results: + model_name = str(cfg.get("model_name", "")) + if not degraded and err is None and endpoints is not None: + gated, h_score = aggregate_health(endpoints) + cfg["health_gated"] = bool(gated) + cfg["quality_score_health"] = h_score + self._health_cache[model_name] = { + "gated": bool(gated), + "score": h_score, + } + else: + cached = self._health_cache.get(model_name) + if cached is not None: + cfg["health_gated"] = bool(cached.get("gated", False)) + cfg["quality_score_health"] = cached.get("score") + # else: keep current values (initial defaults from + # _generate_configs / load_global_llm_configs). + + # Blend health into the final score for every OR cfg, including + # those outside the enriched top-N (they fall through to static). + gated_count = 0 + by_provider: dict[str, int] = {} + for cfg in or_cfgs: + static_q = int(cfg.get("quality_score_static") or 0) + h = cfg.get("quality_score_health") + if h is not None and not cfg.get("health_gated"): + blended = ( + _HEALTH_BLEND_WEIGHT * float(h) + + (1 - _HEALTH_BLEND_WEIGHT) * static_q + ) + cfg["quality_score"] = round(blended) + else: + cfg["quality_score"] = static_q + + if cfg.get("health_gated"): + gated_count += 1 + model_id = str(cfg.get("model_name", "")) + provider_slug = ( + model_id.split("/", 1)[0] if "/" in model_id else "unknown" + ) + by_provider[provider_slug] = by_provider.get(provider_slug, 0) + 1 + + if log_summary: + logger.info( + "auto_pin_health_gated count=%d by_provider=%s fail_ratio=%.2f " + "total_enriched=%d", + gated_count, + dict(sorted(by_provider.items(), key=lambda kv: -kv[1])), + fail_ratio, + len(selected), + ) + + @staticmethod + async def _fetch_endpoints( + client: httpx.AsyncClient, + semaphore: asyncio.Semaphore, + api_key: str, + cfg: dict, + ) -> tuple[dict, list[dict] | None, Exception | None]: + """Fetch ``/api/v1/models/{id}/endpoints`` for one cfg. + + Returns ``(cfg, endpoints, err)`` so the caller can keep batched + results aligned with their cfgs without raising. + """ + model_id = str(cfg.get("model_name", "")) + if not model_id: + return cfg, None, ValueError("missing model_name") + + url = OPENROUTER_ENDPOINTS_URL_TEMPLATE.format(model_id=model_id) + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} + + async with semaphore: + try: + resp = await client.get(url, headers=headers) + resp.raise_for_status() + data = resp.json() + except Exception as exc: + return cfg, None, exc + + payload = data.get("data") if isinstance(data, dict) else None + if not isinstance(payload, dict): + return cfg, None, ValueError("malformed endpoints payload") + endpoints = payload.get("endpoints") + if not isinstance(endpoints, list): + return cfg, [], None + return cfg, endpoints, None async def _refresh_loop(self, interval_hours: float) -> None: interval_sec = interval_hours * 3600 @@ -289,3 +957,34 @@ class OpenRouterIntegrationService: def get_config_by_id(self, config_id: int) -> dict | None: return self._configs_by_id.get(config_id) + + def get_image_generation_configs(self) -> list[dict]: + """Return the dynamic OpenRouter image-generation configs (empty + list when the ``image_generation_enabled`` flag is off). + + Each entry already has ``billing_tier`` derived per-model from + OpenRouter's signals and is shaped to drop directly into + ``Config.GLOBAL_IMAGE_GEN_CONFIGS``. 
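+
+        Abridged sketch of one returned entry (every value below is
+        illustrative, not a real catalogue row)::
+
+            {"id": 910_003, "provider": "OPENROUTER",
+             "model_name": "google/gemini-2.5-flash-image",
+             "billing_tier": "premium", ...}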
+ """ + return list(self._image_configs) + + def get_vision_llm_configs(self) -> list[dict]: + """Return the dynamic OpenRouter vision-LLM configs (empty list + when the ``vision_enabled`` flag is off). + + Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token`` + so ``pricing_registration`` can teach LiteLLM the cost of these + models the same way it does for chat — which keeps the billable + wrapper able to debit accurate micro-USD on a vision call. + """ + return list(self._vision_configs) + + def get_raw_pricing(self) -> dict[str, dict[str, str]]: + """Return the cached raw OpenRouter pricing map. + + Shape: ``{model_id: {"prompt": str, "completion": str}}``. The + values are the strings OpenRouter publishes (USD per token), + never converted to floats here so the caller can decide how to + handle malformed or unset entries. + """ + return dict(self._raw_pricing) diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py new file mode 100644 index 000000000..de98e50c2 --- /dev/null +++ b/surfsense_backend/app/services/pricing_registration.py @@ -0,0 +1,274 @@ +""" +Pricing registration with LiteLLM. + +Many models reach our LiteLLM callback without LiteLLM knowing their +per-token cost — namely: + +* The ~300 dynamic OpenRouter deployments (their pricing only lives on + OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published + pricing table). +* Static YAML deployments whose ``base_model`` name is operator-defined + (e.g. custom Azure deployment names like ``gpt-5.4``) and therefore + not in LiteLLM's table either. + +Without registration, ``kwargs["response_cost"]`` is 0 for those calls +and the user gets billed nothing — a fail-safe but wrong answer for a +cost-based credit system. This module runs once at startup, after the +OpenRouter integration has fetched its catalogue, and registers each +known model's pricing with ``litellm.register_model()`` under multiple +plausible alias keys (LiteLLM's cost lookup may use any of them +depending on whether the call went through the Router, ChatLiteLLM, +or a direct ``acompletion``). + +Operators who run a custom Azure deployment whose ``base_model`` name +isn't in LiteLLM's table can declare per-token pricing inline in +``global_llm_config.yaml`` via ``input_cost_per_token`` and +``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without +that declaration the model's calls debit 0 — never overbilled. +""" + +from __future__ import annotations + +import logging +from typing import Any + +import litellm + +logger = logging.getLogger(__name__) + + +def _safe_float(value: Any) -> float: + """Return ``float(value)`` if it parses to a positive number, else 0.0.""" + if value is None: + return 0.0 + try: + f = float(value) + except (TypeError, ValueError): + return 0.0 + return f if f > 0 else 0.0 + + +def _alias_set_for_openrouter(model_id: str) -> list[str]: + """Return the alias keys to register an OpenRouter model under. + + LiteLLM's cost-callback lookup key varies by call path: + - Router with ``model="openrouter/X"`` → kwargs["model"] is + typically ``openrouter/X``. + - LiteLLM's own provider routing may strip the prefix and pass the + bare ``X`` to the cost-table lookup. + Registering under both keeps the lookup hermetic regardless of + which path the call took. 
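+
+    Example (pure function trace): ``model_id="z-ai/glm-4.6"`` yields
+    ``["openrouter/z-ai/glm-4.6", "z-ai/glm-4.6"]``.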
+    """
+    aliases = [f"openrouter/{model_id}", model_id]
+    return list(dict.fromkeys(a for a in aliases if a))
+
+
+def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
+    """Return the alias keys to register a static YAML deployment under.
+
+    Same reasoning as the OpenRouter set: cover the bare ``base_model``,
+    the ``provider/base_model`` form LiteLLM Router constructs, and the
+    bare ``model_name`` because callbacks sometimes see whichever was
+    configured first.
+    """
+    provider_lower = (provider or "").lower()
+    aliases: list[str] = []
+    if base_model:
+        aliases.append(base_model)
+    if provider_lower and base_model:
+        aliases.append(f"{provider_lower}/{base_model}")
+    if model_name and model_name != base_model:
+        aliases.append(model_name)
+    if provider_lower and model_name and model_name != base_model:
+        aliases.append(f"{provider_lower}/{model_name}")
+    # Azure deployments often surface as "azure/<deployment>"; normalise the
+    # ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
+    if provider_lower == "azure_openai":
+        if base_model:
+            aliases.append(f"azure/{base_model}")
+        if model_name and model_name != base_model:
+            aliases.append(f"azure/{model_name}")
+    return list(dict.fromkeys(a for a in aliases if a))
+
+
+def _register(
+    aliases: list[str],
+    *,
+    input_cost: float,
+    output_cost: float,
+    provider: str,
+    mode: str = "chat",
+) -> int:
+    """Register a single pricing entry under every alias in ``aliases``.
+
+    Returns the count of aliases successfully registered.
+    """
+    payload: dict[str, dict[str, Any]] = {}
+    for alias in aliases:
+        payload[alias] = {
+            "input_cost_per_token": input_cost,
+            "output_cost_per_token": output_cost,
+            "litellm_provider": provider,
+            "mode": mode,
+        }
+    if not payload:
+        return 0
+    try:
+        litellm.register_model(payload)
+    except Exception as exc:
+        logger.warning(
+            "[PricingRegistration] register_model failed for aliases=%s: %s",
+            aliases,
+            exc,
+        )
+        return 0
+    return len(payload)
+
+
+def _register_chat_shape_configs(
+    configs: list[dict],
+    *,
+    or_pricing: dict[str, dict[str, str]],
+    label: str,
+) -> tuple[int, int, int, list[str]]:
+    """Common loop that registers per-token pricing for a list of "chat-shape"
+    configs (chat or vision LLM — both use ``input_cost_per_token`` /
+    ``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).
+
+    Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
+    """
+    registered_models = 0
+    registered_aliases = 0
+    skipped_no_pricing = 0
+    sample_keys: list[str] = []
+
+    for cfg in configs:
+        provider = str(cfg.get("provider") or "").upper()
+        model_name = str(cfg.get("model_name") or "").strip()
+        litellm_params = cfg.get("litellm_params") or {}
+        base_model = str(litellm_params.get("base_model") or model_name).strip()
+
+        if provider == "OPENROUTER":
+            entry = or_pricing.get(model_name)
+            if entry:
+                input_cost = _safe_float(entry.get("prompt"))
+                output_cost = _safe_float(entry.get("completion"))
+            else:
+                # Vision configs from ``_generate_vision_llm_configs``
+                # carry their pricing inline because the OpenRouter
+                # raw-pricing cache is keyed by chat-catalogue model_id;
+                # vision flows pick up the inline values here.
+ input_cost = _safe_float(cfg.get("input_cost_per_token")) + output_cost = _safe_float(cfg.get("output_cost_per_token")) + if input_cost == 0.0 and output_cost == 0.0: + skipped_no_pricing += 1 + continue + aliases = _alias_set_for_openrouter(model_name) + count = _register( + aliases, + input_cost=input_cost, + output_cost=output_cost, + provider="openrouter", + ) + if count > 0: + registered_models += 1 + registered_aliases += count + if len(sample_keys) < 6: + sample_keys.extend(aliases[:2]) + continue + + input_cost = _safe_float( + cfg.get("input_cost_per_token") + or litellm_params.get("input_cost_per_token") + ) + output_cost = _safe_float( + cfg.get("output_cost_per_token") + or litellm_params.get("output_cost_per_token") + ) + if input_cost == 0.0 and output_cost == 0.0: + skipped_no_pricing += 1 + continue + aliases = _alias_set_for_yaml(provider, model_name, base_model) + provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower() + count = _register( + aliases, + input_cost=input_cost, + output_cost=output_cost, + provider=provider_slug, + ) + if count > 0: + registered_models += 1 + registered_aliases += count + if len(sample_keys) < 6: + sample_keys.extend(aliases[:2]) + + logger.info( + "[PricingRegistration:%s] registered pricing for %d models (%d aliases); " + "%d configs had no pricing data; sample registered keys=%s", + label, + registered_models, + registered_aliases, + skipped_no_pricing, + sample_keys, + ) + return registered_models, registered_aliases, skipped_no_pricing, sample_keys + + +def register_pricing_from_global_configs() -> None: + """Register pricing for every known LLM deployment with LiteLLM. + + Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS`` + so vision calls (during indexing) can resolve cost the same way chat + calls do — namely: + + 1. ``OPENROUTER``: pulls the cached raw pricing from + ``OpenRouterIntegrationService`` (populated during its own + startup fetch) and converts the per-token strings to floats. For + vision configs that carry pricing inline (``input_cost_per_token`` / + ``output_cost_per_token`` set on the cfg itself) we fall back to + those values when the OR cache misses the model. + 2. Anything else: looks for operator-declared + ``input_cost_per_token`` / ``output_cost_per_token`` on the YAML + config block (top-level or nested under ``litellm_params``). + + **Image generation is intentionally NOT registered here.** The cost + shape for image-gen is per-image (``output_cost_per_image``), not + per-token, and LiteLLM's ``register_model`` doesn't accept those + keys via the chat-cost path. OpenRouter image-gen models populate + ``response_cost`` directly from their response header instead, and + Azure-native image-gen models are already in LiteLLM's cost map. + + Calls without a resolved pair of costs are skipped, not registered + with zeros — operators who forget pricing get a "$0 debit" warning + in ``TokenTrackingCallback`` rather than silently overwriting any + pricing LiteLLM might know natively. 
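+
+    A YAML cfg that opts into path 2 reaches this loop as a dict of
+    roughly this shape (deployment name and prices are illustrative)::
+
+        {"provider": "AZURE_OPENAI", "model_name": "my-gpt-deployment",
+         "input_cost_per_token": 2e-06,     # $2 per 1M prompt tokens
+         "output_cost_per_token": 8e-06}    # $8 per 1M completion tokens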
+ """ + from app.config import config as app_config + + chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or []) + vision_configs: list[dict] = list( + getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or [] + ) + if not chat_configs and not vision_configs: + logger.info("[PricingRegistration] no global configs to register") + return + + or_pricing: dict[str, dict[str, str]] = {} + try: + from app.services.openrouter_integration_service import ( + OpenRouterIntegrationService, + ) + + if OpenRouterIntegrationService.is_initialized(): + or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing() + except Exception as exc: + logger.debug( + "[PricingRegistration] OpenRouter pricing not available yet: %s", exc + ) + + if chat_configs: + _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat") + if vision_configs: + _register_chat_shape_configs( + vision_configs, or_pricing=or_pricing, label="vision" + ) diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py new file mode 100644 index 000000000..dca1f9462 --- /dev/null +++ b/surfsense_backend/app/services/provider_api_base.py @@ -0,0 +1,106 @@ +"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision. + +LiteLLM falls back to the module-global ``litellm.api_base`` when an +individual call doesn't pass one, which silently inherits provider-agnostic +env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an +explicit ``api_base``, an ``openrouter/`` request can end up at an +Azure endpoint and 404 with ``Resource not found`` (real reproducer: +[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends +``/chat/completions`` to whatever inherited base it gets, regardless of +provider). + +The chat router has had this defense for a while +(``llm_router_service.py:466-478``). This module hoists the maps + cascade +into a tiny standalone helper so vision and image-gen can share the same +source of truth without an inter-service circular import. +""" + +from __future__ import annotations + +PROVIDER_DEFAULT_API_BASE: dict[str, str] = { + "openrouter": "https://openrouter.ai/api/v1", + "groq": "https://api.groq.com/openai/v1", + "mistral": "https://api.mistral.ai/v1", + "perplexity": "https://api.perplexity.ai", + "xai": "https://api.x.ai/v1", + "cerebras": "https://api.cerebras.ai/v1", + "deepinfra": "https://api.deepinfra.com/v1/openai", + "fireworks_ai": "https://api.fireworks.ai/inference/v1", + "together_ai": "https://api.together.xyz/v1", + "anyscale": "https://api.endpoints.anyscale.com/v1", + "cometapi": "https://api.cometapi.com/v1", + "sambanova": "https://api.sambanova.ai/v1", +} +"""Default ``api_base`` per LiteLLM provider prefix (lowercase). + +Only providers with a well-known, stable public base URL are listed — +self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, +huggingface, databricks, cloudflare, replicate) are intentionally omitted +so their existing config-driven behaviour is preserved.""" + + +PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = { + "DEEPSEEK": "https://api.deepseek.com/v1", + "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + "MOONSHOT": "https://api.moonshot.ai/v1", + "ZHIPU": "https://open.bigmodel.cn/api/paas/v4", + "MINIMAX": "https://api.minimax.io/v1", +} +"""Canonical provider key (uppercase) → base URL. 
+ +Used when the LiteLLM provider prefix is the generic ``openai`` shim but the +config's ``provider`` field tells us which API it actually is (DeepSeek, +Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each +has its own base URL).""" + + +def resolve_api_base( + *, + provider: str | None, + provider_prefix: str | None, + config_api_base: str | None, +) -> str | None: + """Resolve a non-Azure-leaking ``api_base`` for a deployment. + + Cascade (first non-empty wins): + 1. The config's own ``api_base`` (whitespace-only treated as missing). + 2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``. + 3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``. + 4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM + provider integration apply its own default (e.g. AzureOpenAI's + deployment-derived URL, custom provider's per-deployment URL). + + Args: + provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``, + ``"DEEPSEEK"``). Case-insensitive. + provider_prefix: The LiteLLM model-string prefix the same call + site builds for the model id (e.g. ``"openrouter"``, + ``"groq"``). Case-insensitive. + config_api_base: ``api_base`` from the global YAML / DB row / + OpenRouter dynamic config. Empty / whitespace-only means + "missing" — the resolver still applies the cascade. + + Returns: + A URL string, or ``None`` if no default applies for this provider. + """ + if config_api_base and config_api_base.strip(): + return config_api_base + + if provider: + key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper()) + if key_default: + return key_default + + if provider_prefix: + prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower()) + if prefix_default: + return prefix_default + + return None + + +__all__ = [ + "PROVIDER_DEFAULT_API_BASE", + "PROVIDER_KEY_DEFAULT_API_BASE", + "resolve_api_base", +] diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py new file mode 100644 index 000000000..e9a1c33e1 --- /dev/null +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -0,0 +1,280 @@ +"""Capability resolution shared by chat / image / vision call sites. + +Why this exists +--------------- +The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a +single, authoritative answer to one question: *can this chat config accept +``image_url`` content blocks?* Without it, the new-chat selector can't badge +incompatible models and the streaming task can't fail fast with a friendly +error before sending an image to a text-only provider. + +Two functions, two intents: + +- :func:`derive_supports_image_input` — best-effort *True* for catalog and + UI surfacing. Default-allow: an unknown / unmapped model is treated as + capable so we never lock the user out of a freshly added or + third-party-hosted vision model. + +- :func:`is_known_text_only_chat_model` — strict opt-out for the streaming + task's safety net. Returns True only when LiteLLM's model map *explicitly* + sets ``supports_vision=False`` (or its bare-name variant does). Anything + else — missing key, lookup exception, ``supports_vision=True`` — returns + False so the request flows through to the provider. + +Implementation rule: only public LiteLLM symbols +------------------------------------------------ +``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the +typed module surface (see ``litellm.__init__`` lazy stubs) and are stable +across releases. 
The private ``_is_explicitly_disabled_factory`` and +``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade +can't silently break us. + +Why the previous round's strict YAML opt-in flag failed +------------------------------------------------------- +``supports_image_input: false`` was the YAML loader's setdefault. Operators +maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI +YAML chat model — including vision-capable GPT-5.x and GPT-4o — resolved to +False and the streaming gate rejected every image turn. Sourcing capability +from LiteLLM's authoritative model map (which already says +``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil. +""" + +from __future__ import annotations + +import logging +from collections.abc import Iterable + +import litellm + +logger = logging.getLogger(__name__) + + +# Provider-name → LiteLLM model-prefix map. +# +# Owned here because ``app.services.provider_capabilities`` is the +# only edge that's safe to call from ``app.config``'s YAML loader at +# class-body init time. ``app.agents.new_chat.llm_config`` re-exports +# this constant under the historical ``PROVIDER_MAP`` name; placing the +# map there directly would re-introduce the +# ``app.config -> ... -> app.agents.new_chat.tools.generate_image -> +# app.config`` cycle that prompted the move. +_PROVIDER_PREFIX_MAP: dict[str, str] = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "XAI": "xai", + "BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "GITHUB_MODELS": "github", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "COMETAPI": "cometapi", + "HUGGINGFACE": "huggingface", + "MINIMAX": "openai", + "CUSTOM": "custom", +} + + +def _candidate_model_strings( + *, + provider: str | None, + model_name: str | None, + base_model: str | None, + custom_provider: str | None, +) -> list[tuple[str, str | None]]: + """Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates. + + LiteLLM's capability lookup is keyed by ``model`` + (optional) + ``custom_llm_provider``. Different config sources give us different + levels of detail, so we try the most-specific keys first and fall back + to bare model names so unannotated entries (e.g. an Azure deployment + pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the + map. Order matters — the first lookup that returns a definitive answer + wins for both helpers. + """ + candidates: list[tuple[str, str | None]] = [] + seen: set[tuple[str, str | None]] = set() + + def _add(model: str | None, llm_provider: str | None) -> None: + if not model: + return + key = (model, llm_provider) + if key in seen: + return + seen.add(key) + candidates.append(key) + + provider_prefix: str | None = None + if provider: + provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower()) + if custom_provider: + # ``custom_provider`` overrides everything for CUSTOM/proxy setups. 
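+        # e.g. a CUSTOM cfg with custom_provider="openai" probes
+        # "openai/<model>" candidates instead of "custom/<model>".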
+ provider_prefix = custom_provider + + primary_model = base_model or model_name + bare_model = model_name + + # Most-specific first: provider-prefixed identifier with explicit + # custom_llm_provider so LiteLLM won't have to guess the provider via + # ``get_llm_provider``. + if primary_model and provider_prefix: + # e.g. "azure/gpt-5.4" + custom_llm_provider="azure" + if "/" in primary_model: + _add(primary_model, provider_prefix) + else: + _add(f"{provider_prefix}/{primary_model}", provider_prefix) + + # Bare base_model (or model_name) with provider hint — handles entries + # the upstream map keys without a provider prefix (most ``gpt-*`` and + # ``claude-*`` entries do this). + if primary_model: + _add(primary_model, provider_prefix) + + # Fallback to model_name when base_model differs (e.g. an Azure + # deployment whose model_name is the deployment id but base_model is the + # canonical OpenAI sku). + if bare_model and bare_model != primary_model: + if provider_prefix and "/" not in bare_model: + _add(f"{provider_prefix}/{bare_model}", provider_prefix) + _add(bare_model, provider_prefix) + _add(bare_model, None) + + return candidates + + +def derive_supports_image_input( + *, + provider: str | None = None, + model_name: str | None = None, + base_model: str | None = None, + custom_provider: str | None = None, + openrouter_input_modalities: Iterable[str] | None = None, +) -> bool: + """Best-effort capability flag for the new-chat selector and catalog. + + Resolution order (first definitive answer wins): + + 1. ``openrouter_input_modalities`` (when provided as a non-empty + iterable). OpenRouter exposes ``architecture.input_modalities`` per + model and that's the authoritative source for OR dynamic configs. + 2. ``litellm.supports_vision`` against each candidate identifier from + :func:`_candidate_model_strings`. Returns True as soon as any + candidate confirms vision support. + 3. Default ``True`` — the conservative-allow stance. An unknown / + newly-added / third-party-hosted model is *not* pre-judged. The + streaming safety net (:func:`is_known_text_only_chat_model`) is the + only place a False ever blocks; everywhere else, a False here would + just hide a usable model from the user. + + Returns: + True if the model can plausibly accept image input, False only when + OpenRouter explicitly says it can't. + """ + if openrouter_input_modalities is not None: + modalities = list(openrouter_input_modalities) + if modalities: + return "image" in modalities + # Empty list explicitly published by OR — treat as "no image". + return False + + for model_string, custom_llm_provider in _candidate_model_strings( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ): + try: + if litellm.supports_vision( + model=model_string, custom_llm_provider=custom_llm_provider + ): + return True + except Exception as exc: + logger.debug( + "litellm.supports_vision raised for model=%s provider=%s: %s", + model_string, + custom_llm_provider, + exc, + ) + continue + + # Default-allow. ``is_known_text_only_chat_model`` is the strict gate. + return True + + +def is_known_text_only_chat_model( + *, + provider: str | None = None, + model_name: str | None = None, + base_model: str | None = None, + custom_provider: str | None = None, +) -> bool: + """Strict opt-out probe for the streaming-task safety net. + + Returns True only when LiteLLM's model map *explicitly* sets + ``supports_vision=False`` for at least one candidate identifier. 
Missing + key, lookup exception, or ``supports_vision=True`` all return False so + the streaming task lets the request through. This is the inverse-default + of :func:`derive_supports_image_input`. + + Why two functions + ----------------- + The selector wants "show me everything that's plausibly capable" — + default-allow. The safety net wants "block only when I'm certain it + can't" — default-pass. Mixing the two intents in a single function + leads to the regression we're fixing here. + """ + for model_string, custom_llm_provider in _candidate_model_strings( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ): + try: + info = litellm.get_model_info( + model=model_string, custom_llm_provider=custom_llm_provider + ) + except Exception as exc: + logger.debug( + "litellm.get_model_info raised for model=%s provider=%s: %s", + model_string, + custom_llm_provider, + exc, + ) + continue + + # ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision`` + # may be missing, None, True, or False. We only fire on explicit + # False — None / missing / True all mean "don't block". + try: + value = info.get("supports_vision") # type: ignore[union-attr] + except AttributeError: + value = None + if value is False: + return True + + return False + + +__all__ = [ + "derive_supports_image_input", + "is_known_text_only_chat_model", +] diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py new file mode 100644 index 000000000..2fb37de21 --- /dev/null +++ b/surfsense_backend/app/services/quality_score.py @@ -0,0 +1,380 @@ +"""Pure-function quality scoring for Auto (Fastest) model selection. + +This module is import-free of any service / request-path dependencies. All +numbers are computed once during the OpenRouter refresh tick (or YAML load) +and cached on the cfg dict, so the chat hot path only does a precomputed +sort and a SHA256 pick. + +Score components (0-100 scale, higher is better): + +* ``static_score_or`` - derived from the bulk ``/api/v1/models`` payload + (provider prestige + ``created`` recency + pricing band + context window + + capabilities + narrow tiny/legacy slug penalty). +* ``static_score_yaml`` - same shape for hand-curated YAML configs, plus + an operator-trust bonus (the operator deliberately picked this model). +* ``aggregate_health`` - run on per-model ``/api/v1/models/{id}/endpoints`` + responses; returns ``(gated, score_or_none)``. + +The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in +:mod:`app.services.openrouter_integration_service` because that's the only +caller that sees both halves. +""" + +from __future__ import annotations + +# --------------------------------------------------------------------------- +# Tunables (constants, not flags) +# --------------------------------------------------------------------------- + +# Top-K size for deterministic spread inside the locked tier. +_QUALITY_TOP_K: int = 5 + +# Hard health gate: any cfg whose best non-null uptime is below this % +# is excluded from Auto-mode selection entirely. +_HEALTH_GATE_UPTIME_PCT: float = 90.0 + +# Health/static blend weight when a cfg has fresh /endpoints data. +_HEALTH_BLEND_WEIGHT: float = 0.5 + +# Static bonus applied to YAML cfgs because the operator hand-picked them. +_OPERATOR_TRUST_BONUS: int = 20 + +# /endpoints fan-out is bounded per refresh tick. 
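+# Worst case per tick with the defaults below: 50 premium + 30 free
+# probes squeezed through a 15-slot semaphore, each bounded by the 5 s
+# per-request timeout.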
+_HEALTH_ENRICH_TOP_N_PREMIUM: int = 50 +_HEALTH_ENRICH_TOP_N_FREE: int = 30 +_HEALTH_ENRICH_CONCURRENCY: int = 15 +_HEALTH_FETCH_TIMEOUT_SEC: float = 5.0 + +# If at least this fraction of /endpoints fetches fail in a refresh cycle, +# fall back to the previous cycle's last-good cache instead of writing +# partial / stale health values. +_HEALTH_FAIL_RATIO_FALLBACK: float = 0.25 + +# Narrow tiny/legacy slug penalties only. We deliberately do NOT penalise +# ``-nano`` / ``-mini`` / ``-lite`` because modern frontier models ship with +# those naming patterns (``gpt-5-mini``, ``gemini-2.5-flash-lite`` etc.) and +# blanket-penalising them suppresses high-quality picks. +_TINY_LEGACY_PENALTY_PATTERNS: tuple[str, ...] = ( + "-1b-", + "-1.2b-", + "-1.5b-", + "-2b-", + "-3b-", + "gemma-3n", + "lfm-", + "-base", + "-distill", + ":nitro", + "-preview", +) + + +# --------------------------------------------------------------------------- +# Provider prestige tables +# --------------------------------------------------------------------------- + +# OpenRouter-side provider slug (the prefix before ``/`` in the model id). +# Tiers are coarse: frontier labs > strong open / fast-moving labs > +# specialist labs > everything else. +PROVIDER_PRESTIGE_OR: dict[str, int] = { + # Frontier labs + "openai": 50, + "anthropic": 50, + "google": 50, + "x-ai": 50, + # Strong open / fast-moving labs + "deepseek": 38, + "qwen": 38, + "meta-llama": 38, + "mistralai": 38, + "cohere": 38, + "nvidia": 38, + "alibaba": 38, + # Specialist / regional / strong second-tier + "microsoft": 28, + "01-ai": 28, + "minimax": 28, + "moonshot": 28, + "z-ai": 28, + "nousresearch": 28, + "ai21": 28, + "perplexity": 28, + # Smaller / niche providers + "liquid": 18, + "cognitivecomputations": 18, + "venice": 18, + "inflection": 18, +} + +# YAML provider field (the upstream API shape the operator selected). +PROVIDER_PRESTIGE_YAML: dict[str, int] = { + "AZURE_OPENAI": 50, + "OPENAI": 50, + "ANTHROPIC": 50, + "GOOGLE": 50, + "VERTEX_AI": 50, + "GEMINI": 50, + "XAI": 50, + "MISTRAL": 38, + "DEEPSEEK": 38, + "COHERE": 38, + "GROQ": 30, + "TOGETHER_AI": 28, + "FIREWORKS_AI": 28, + "PERPLEXITY": 28, + "MINIMAX": 28, + "BEDROCK": 28, + "OPENROUTER": 25, + "OLLAMA": 12, + "CUSTOM": 12, +} + + +# --------------------------------------------------------------------------- +# Pure scoring helpers +# --------------------------------------------------------------------------- + +# Calibrated against the live /api/v1/models bulk dump. Frontier models +# released in the last ~6 months (GPT-5 family, Claude 4.x, Gemini 2.5, +# Grok 4) score in the 18-20 band; mid-2024 models in the 8-12 band; +# anything older trails off. +_RECENCY_BANDS_DAYS: tuple[tuple[int, int], ...] = ( + (60, 20), + (180, 16), + (365, 12), + (540, 9), + (730, 6), + (1095, 3), +) + + +def created_recency_signal(created_ts: int | None, now_ts: int) -> int: + """Return 0-20 based on how recently the model was published. + + Uses the OpenRouter ``created`` Unix timestamp (or any equivalent for + YAML cfgs). Models without a usable timestamp get 0 (we don't penalise, + we just don't reward). 
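+
+    Worked examples against the bands above (timestamps invented)::
+
+        >>> now = 1_760_000_000
+        >>> created_recency_signal(now - 30 * 86_400, now)    # ~1 month old
+        20
+        >>> created_recency_signal(now - 400 * 86_400, now)   # ~13 months old
+        9
+        >>> created_recency_signal(None, now)
+        0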
+ """ + if created_ts is None or created_ts <= 0 or now_ts <= 0: + return 0 + age_days = max(0, (now_ts - int(created_ts)) // 86_400) + for cutoff, score in _RECENCY_BANDS_DAYS: + if age_days <= cutoff: + return score + return 0 + + +def pricing_band( + prompt: str | float | int | None, + completion: str | float | int | None, +) -> int: + """Return 0-15 based on combined prompt+completion cost per 1M tokens. + + Higher-priced models tend to be the larger / more capable ones. A free + model returns 0 (we use other signals to rank free-vs-free instead). + Uncoercible inputs are treated as 0 rather than raising. + """ + + def _to_float(value) -> float: + if value is None: + return 0.0 + try: + return float(value) + except (TypeError, ValueError): + return 0.0 + + p = _to_float(prompt) + c = _to_float(completion) + total_per_million = (p + c) * 1_000_000 + + if total_per_million >= 20.0: + return 15 + if total_per_million >= 5.0: + return 12 + if total_per_million >= 1.0: + return 9 + if total_per_million >= 0.3: + return 6 + if total_per_million >= 0.05: + return 4 + if total_per_million > 0.0: + return 2 + return 0 + + +def context_signal(ctx: int | None) -> int: + """Return 0-10 based on the model's context window.""" + if not ctx or ctx <= 0: + return 0 + if ctx >= 1_000_000: + return 10 + if ctx >= 400_000: + return 8 + if ctx >= 200_000: + return 6 + if ctx >= 128_000: + return 4 + if ctx >= 100_000: + return 2 + return 0 + + +def capabilities_signal(supported_parameters: list[str] | None) -> int: + """Return 0-5 for capabilities that matter for our agent flows.""" + if not supported_parameters: + return 0 + params = set(supported_parameters) + score = 0 + if "tools" in params: + score += 2 + if "structured_outputs" in params or "response_format" in params: + score += 2 + if "reasoning" in params or "include_reasoning" in params: + score += 1 + return min(score, 5) + + +def slug_penalty(model_id: str) -> int: + """Return a non-positive number; matches the narrow tiny/legacy patterns.""" + if not model_id: + return 0 + needle = model_id.lower() + for pattern in _TINY_LEGACY_PENALTY_PATTERNS: + if pattern in needle: + return -10 + return 0 + + +def _provider_prestige_or(model_id: str) -> int: + if "/" not in model_id: + return 0 + slug = model_id.split("/", 1)[0].lower() + return PROVIDER_PRESTIGE_OR.get(slug, 15) + + +def static_score_or(or_model: dict, *, now_ts: int) -> int: + """Score a raw OpenRouter ``/api/v1/models`` entry on a 0-100 scale.""" + model_id = str(or_model.get("id", "")) + pricing = or_model.get("pricing") or {} + + score = ( + _provider_prestige_or(model_id) + + created_recency_signal(or_model.get("created"), now_ts) + + pricing_band(pricing.get("prompt"), pricing.get("completion")) + + context_signal(or_model.get("context_length")) + + capabilities_signal(or_model.get("supported_parameters")) + + slug_penalty(model_id) + ) + return max(0, min(100, int(score))) + + +def static_score_yaml(cfg: dict) -> int: + """Score a YAML-curated cfg on a 0-100 scale. + + Includes ``_OPERATOR_TRUST_BONUS`` because the operator deliberately + listed this model. Pricing / context fall through to lazy ``litellm`` + lookups; failures are silent (we just lose those sub-points). 
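+
+    Illustrative arithmetic (the cost / context figures are placeholders,
+    not LiteLLM's real table values)::
+
+        AZURE_OPENAI cfg resolving to a gpt-4o-class model:
+          50 (prestige) + 20 (operator trust)
+          + pricing_band(2.5e-06, 1e-05)   # $12.50 per 1M combined -> 12
+          + context_signal(128_000)        # -> 4
+          + slug_penalty("gpt-4o")         # -> 0
+          = 86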
+ """ + provider = str(cfg.get("provider", "")).upper() + base = PROVIDER_PRESTIGE_YAML.get(provider, 15) + + model_name = cfg.get("model_name") or "" + litellm_params = cfg.get("litellm_params") or {} + lookup_name = ( + litellm_params.get("base_model") or litellm_params.get("model") or model_name + ) + + ctx = 0 + p_cost: float = 0.0 + c_cost: float = 0.0 + try: + from litellm import get_model_info # lazy: avoid cold-import cost + + info = get_model_info(lookup_name) or {} + ctx = int(info.get("max_input_tokens") or info.get("max_tokens") or 0) + p_cost = float(info.get("input_cost_per_token") or 0.0) + c_cost = float(info.get("output_cost_per_token") or 0.0) + except Exception: + # Unknown to litellm — that's fine for prestige+operator-bonus weighting. + pass + + score = ( + base + + _OPERATOR_TRUST_BONUS + + pricing_band(p_cost, c_cost) + + context_signal(ctx) + + slug_penalty(str(model_name)) + ) + return max(0, min(100, int(score))) + + +# --------------------------------------------------------------------------- +# Health aggregation +# --------------------------------------------------------------------------- + + +def _coerce_pct(value) -> float | None: + try: + if value is None: + return None + f = float(value) + except (TypeError, ValueError): + return None + if f < 0: + return None + # OpenRouter reports uptime as a 0-1 fraction; some endpoints surface it + # as a 0-100 percentage. Normalise. + return f * 100.0 if f <= 1.0 else f + + +def _best_uptime(endpoints: list[dict]) -> tuple[float | None, str | None]: + """Pick the best (highest) non-null uptime across all endpoints. + + Window preference: ``uptime_last_30m`` > ``uptime_last_1d`` > + ``uptime_last_5m``. Returns ``(uptime_pct, window_used)``. + """ + for window in ("uptime_last_30m", "uptime_last_1d", "uptime_last_5m"): + values = [_coerce_pct(ep.get(window)) for ep in endpoints] + values = [v for v in values if v is not None] + if values: + return max(values), window + return None, None + + +def aggregate_health(endpoints: list[dict]) -> tuple[bool, float | None]: + """Aggregate a model's per-endpoint health into ``(gated, score_or_none)``. + + Hard gate (returns ``(True, None)``): + * ``endpoints`` empty, + * no endpoint reports ``status == 0`` (OK), or + * best non-null uptime below ``_HEALTH_GATE_UPTIME_PCT``. + + On a pass, returns a 0-100 health score blending uptime, status, and a + freshness-weighted recent uptime sample. + """ + if not endpoints: + return True, None + + any_ok = any(int(ep.get("status", 1)) == 0 for ep in endpoints) + if not any_ok: + return True, None + + best_uptime, _ = _best_uptime(endpoints) + if best_uptime is None or best_uptime < _HEALTH_GATE_UPTIME_PCT: + return True, None + + # Freshness term: prefer 5m, fall through to 30m / 1d if 5m is missing. 
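+    # Illustrative blend (numbers invented): best uptime 99.0, at least
+    # one OK endpoint, freshness 97.0:
+    #   0.50 * 99.0 + 0.30 * 100.0 + 0.20 * 97.0 = 98.9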
+ freshness = None + for window in ("uptime_last_5m", "uptime_last_30m", "uptime_last_1d"): + values = [_coerce_pct(ep.get(window)) for ep in endpoints] + values = [v for v in values if v is not None] + if values: + freshness = max(values) + break + + uptime_term = best_uptime + status_term = 100.0 if any_ok else 0.0 + freshness_term = freshness if freshness is not None else best_uptime + + score = 0.50 * uptime_term + 0.30 * status_term + 0.20 * freshness_term + return False, max(0.0, min(100.0, score)) diff --git a/surfsense_backend/app/services/quota_checked_vision_llm.py b/surfsense_backend/app/services/quota_checked_vision_llm.py new file mode 100644 index 000000000..0040e5a5b --- /dev/null +++ b/surfsense_backend/app/services/quota_checked_vision_llm.py @@ -0,0 +1,105 @@ +""" +Vision LLM proxy that enforces premium credit quota on every ``ainvoke``. + +Used by :func:`app.services.llm_service.get_vision_llm` so callers in the +indexing pipeline (file processors, connector indexers, etl pipeline) can +keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)`` +— without threading ``user_id`` through every parser. The wrapper looks like +a chat model from the outside; on the inside it routes each call through +``billable_call`` so the user's premium credit pool is reserved → finalized +or released, and a ``TokenUsage`` audit row is written. + +Free configs are returned unwrapped from ``get_vision_llm`` (they do not +need quota enforcement) so this class only ever wraps premium configs. + +Why a wrapper instead of plumbing ``user_id`` through every caller: + +* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive, + Dropbox, local-folder, file-processor, ETL pipeline) each calling + ``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is + invasive, error-prone, and easy for a future indexer to forget. +* Per the design (issue M), we always debit the *search-space owner*, not + the triggering user, so ``user_id`` is fully derivable from the search + space the caller is already operating on. The wrapper captures it once + at construction time. +* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each + call run this coroutine"; subclassing isn't safe across versions because + it derives from ``BaseChatModel`` which expects specific Pydantic shapes. + Composition via attribute proxying (``__getattr__``) is robust to + upstream changes — every method other than ``ainvoke`` falls through to + the inner LLM unchanged. +""" + +from __future__ import annotations + +import logging +from typing import Any +from uuid import UUID + +from app.services.billable_calls import QuotaInsufficientError, billable_call + +logger = logging.getLogger(__name__) + + +class QuotaCheckedVisionLLM: + """Composition wrapper around a langchain chat model that enforces + premium credit quota on every ``ainvoke``. + + Anything other than ``ainvoke`` is forwarded to the inner model so + ``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all + still work — they simply bypass quota enforcement, which is fine + because the indexing pipeline only ever calls ``ainvoke`` today. 
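+
+    Minimal usage sketch (names illustrative, not a prescribed call
+    site)::
+
+        llm = QuotaCheckedVisionLLM(
+            inner_vision_llm,
+            user_id=owner_id,
+            search_space_id=space_id,
+            billing_tier="premium",
+            base_model="openrouter/some-lab/some-vision-model",
+            quota_reserve_tokens=4000,
+        )
+        try:
+            description = await llm.ainvoke(messages_with_image)
+        except QuotaInsufficientError:
+            ...  # fall back to the plain document parser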
+ """ + + def __init__( + self, + inner_llm: Any, + *, + user_id: UUID, + search_space_id: int, + billing_tier: str, + base_model: str, + quota_reserve_tokens: int | None, + usage_type: str = "vision_extraction", + ) -> None: + self._inner = inner_llm + self._user_id = user_id + self._search_space_id = search_space_id + self._billing_tier = billing_tier + self._base_model = base_model + self._quota_reserve_tokens = quota_reserve_tokens + self._usage_type = usage_type + + async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any: + """Proxied async invoke that runs the underlying call inside + ``billable_call``. + + Raises: + QuotaInsufficientError: when the user has exhausted their + premium credit pool. Caller (``etl_pipeline_service._extract_image``) + catches this and falls back to the document parser. + """ + async with billable_call( + user_id=self._user_id, + search_space_id=self._search_space_id, + billing_tier=self._billing_tier, + base_model=self._base_model, + quota_reserve_tokens=self._quota_reserve_tokens, + usage_type=self._usage_type, + call_details={"model": self._base_model}, + ): + return await self._inner.ainvoke(input, *args, **kwargs) + + def __getattr__(self, name: str) -> Any: + """Forward everything else (``invoke``, ``astream``, ``bind``, + ``with_structured_output``, …) to the inner model. + + ``__getattr__`` is only consulted when the attribute is *not* + already found on the proxy, which is exactly the contract we + want — methods we override stay on the proxy, the rest fall + through. + """ + return getattr(self._inner, name) + + +__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"] diff --git a/surfsense_backend/app/services/token_quota_service.py b/surfsense_backend/app/services/token_quota_service.py index a3ec7aed0..310c3eb5e 100644 --- a/surfsense_backend/app/services/token_quota_service.py +++ b/surfsense_backend/app/services/token_quota_service.py @@ -22,6 +22,71 @@ from app.config import config logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Per-call reservation estimator (USD micro-units) +# --------------------------------------------------------------------------- + +# Minimum reserve in micros so a user with $0.0001 left can still make a tiny +# request, and so models without registered pricing reserve at least +# something while the call runs (debited 0 at finalize anyway when their +# cost can't be resolved). +_QUOTA_MIN_RESERVE_MICROS = 100 + + +def estimate_call_reserve_micros( + *, + base_model: str, + quota_reserve_tokens: int | None, +) -> int: + """Return the number of micro-USD to reserve for one premium call. + + Computes a worst-case upper bound from LiteLLM's per-token pricing + table: + + reserve_usd ≈ reserve_tokens x (input_cost + output_cost) + + so the math scales with model cost — Claude Opus + 4K reserve_tokens + naturally reserves ≈ $0.36, while a cheap model reserves only a few + cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]`` + so a misconfigured "$1000/M" model can't lock the whole balance on + one call. + + If ``litellm.get_model_info`` raises (model unknown) we fall back to + the floor — 100 micros / $0.0001 — which is enough to gate a sane + request without over-reserving for a model whose pricing the + operator hasn't declared yet. 
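+
+    Worked example (Opus-class pricing, $15/M prompt + $75/M completion)::
+
+        reserve_tokens = 4_000
+        per_token      = 0.000015 + 0.000075   # $0.00009
+        reserve_usd    = 4_000 * 0.00009       # $0.36
+        reserve_micros = 360_000               # within the clamp window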
+ """ + reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL + if reserve_tokens <= 0: + reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL + + try: + from litellm import get_model_info + + info = get_model_info(base_model) if base_model else {} + input_cost = float(info.get("input_cost_per_token") or 0.0) + output_cost = float(info.get("output_cost_per_token") or 0.0) + except Exception as exc: + logger.debug( + "[quota_reserve] cost lookup failed for base_model=%s: %s", + base_model, + exc, + ) + input_cost = 0.0 + output_cost = 0.0 + + if input_cost == 0.0 and output_cost == 0.0: + return _QUOTA_MIN_RESERVE_MICROS + + reserve_usd = reserve_tokens * (input_cost + output_cost) + reserve_micros = round(reserve_usd * 1_000_000) + if reserve_micros < _QUOTA_MIN_RESERVE_MICROS: + reserve_micros = _QUOTA_MIN_RESERVE_MICROS + if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS: + reserve_micros = config.QUOTA_MAX_RESERVE_MICROS + return reserve_micros + + class QuotaScope(StrEnum): ANONYMOUS = "anonymous" PREMIUM = "premium" @@ -444,8 +509,16 @@ class TokenQuotaService: db_session: AsyncSession, user_id: Any, request_id: str, - reserve_tokens: int, + reserve_micros: int, ) -> QuotaResult: + """Reserve ``reserve_micros`` (USD micro-units) from the user's + premium credit balance. + + ``QuotaResult.used``/``limit``/``reserved``/``remaining`` are + all in micro-USD on this code path; callers (chat stream, + token-status route, FE display) convert to dollars by dividing + by 1_000_000. + """ from app.db import User user = ( @@ -465,11 +538,11 @@ class TokenQuotaService: limit=0, ) - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved - effective = used + reserved + reserve_tokens + effective = used + reserved + reserve_micros if effective > limit: remaining = max(0, limit - used - reserved) await db_session.rollback() @@ -482,10 +555,10 @@ class TokenQuotaService: remaining=remaining, ) - user.premium_tokens_reserved = reserved + reserve_tokens + user.premium_credit_micros_reserved = reserved + reserve_micros await db_session.commit() - new_reserved = reserved + reserve_tokens + new_reserved = reserved + reserve_micros remaining = max(0, limit - used - new_reserved) warning_threshold = int(limit * 0.8) @@ -510,9 +583,12 @@ class TokenQuotaService: db_session: AsyncSession, user_id: Any, request_id: str, - actual_tokens: int, - reserved_tokens: int, + actual_micros: int, + reserved_micros: int, ) -> QuotaResult: + """Settle the reservation: release ``reserved_micros`` and debit + ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD). 
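+
+        Example (numbers invented): a call reserved 360_000 micros up
+        front and LiteLLM reports an actual cost of $0.042; the hold
+        shrinks by 360_000 and 42_000 micros are debited for good.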
+ """ from app.db import User user = ( @@ -529,16 +605,18 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - user.premium_tokens_reserved = max( - 0, user.premium_tokens_reserved - reserved_tokens + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros + ) + user.premium_credit_micros_used = ( + user.premium_credit_micros_used + actual_micros ) - user.premium_tokens_used = user.premium_tokens_used + actual_tokens await db_session.commit() - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved remaining = max(0, limit - used - reserved) warning_threshold = int(limit * 0.8) @@ -562,8 +640,13 @@ class TokenQuotaService: async def premium_release( db_session: AsyncSession, user_id: Any, - reserved_tokens: int, + reserved_micros: int, ) -> None: + """Release ``reserved_micros`` previously held by ``premium_reserve``. + + Used when a request fails before finalize (so the reservation + doesn't leak credit). + """ from app.db import User user = ( @@ -576,8 +659,8 @@ class TokenQuotaService: .scalar_one_or_none() ) if user is not None: - user.premium_tokens_reserved = max( - 0, user.premium_tokens_reserved - reserved_tokens + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros ) await db_session.commit() @@ -598,9 +681,9 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved remaining = max(0, limit - used - reserved) warning_threshold = int(limit * 0.8) diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py index 9aa8c6e70..9406d9be4 100644 --- a/surfsense_backend/app/services/token_tracking_service.py +++ b/surfsense_backend/app/services/token_tracking_service.py @@ -16,11 +16,14 @@ from __future__ import annotations import dataclasses import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from contextvars import ContextVar from dataclasses import dataclass, field from typing import Any from uuid import UUID +import litellm from litellm.integrations.custom_logger import CustomLogger from sqlalchemy.ext.asyncio import AsyncSession @@ -35,6 +38,8 @@ class TokenCallRecord: prompt_tokens: int completion_tokens: int total_tokens: int + cost_micros: int = 0 + call_kind: str = "chat" @dataclass @@ -49,6 +54,8 @@ class TurnTokenAccumulator: prompt_tokens: int, completion_tokens: int, total_tokens: int, + cost_micros: int = 0, + call_kind: str = "chat", ) -> None: self.calls.append( TokenCallRecord( @@ -56,20 +63,28 @@ class TurnTokenAccumulator: prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, + call_kind=call_kind, ) ) def per_message_summary(self) -> dict[str, dict[str, int]]: - """Return token counts grouped by model name.""" + """Return token counts (and cost) grouped by model name.""" by_model: dict[str, dict[str, int]] = {} for c in self.calls: entry = by_model.setdefault( c.model, - {"prompt_tokens": 0, 
"completion_tokens": 0, "total_tokens": 0}, + { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "cost_micros": 0, + }, ) entry["prompt_tokens"] += c.prompt_tokens entry["completion_tokens"] += c.completion_tokens entry["total_tokens"] += c.total_tokens + entry["cost_micros"] += c.cost_micros return by_model @property @@ -84,6 +99,21 @@ class TurnTokenAccumulator: def total_completion_tokens(self) -> int: return sum(c.completion_tokens for c in self.calls) + @property + def total_cost_micros(self) -> int: + """Sum of per-call ``cost_micros`` across the entire turn. + + Used by ``stream_new_chat`` to debit a premium turn's actual + provider cost (in micro-USD) from the user's premium credit + balance. ``cost_micros`` per call is captured by + ``TokenTrackingCallback.async_log_success_event`` from + ``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost), + with multiple fallback paths so OpenRouter dynamic models and + custom Azure deployments still bill correctly when our + ``pricing_registration`` ran at startup. + """ + return sum(c.cost_micros for c in self.calls) + def serialized_calls(self) -> list[dict[str, Any]]: return [dataclasses.asdict(c) for c in self.calls] @@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar( def start_turn() -> TurnTokenAccumulator: - """Create a fresh accumulator for the current async context and return it.""" + """Create a fresh accumulator for the current async context and return it. + + NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For + short-lived per-call billable wrappers (image generation REST endpoint, + vision LLM during indexing) prefer :func:`scoped_turn`, which uses a + ContextVar reset token to restore the *previous* accumulator on exit and + avoids leaking call records across reservations (issue B). + """ acc = TurnTokenAccumulator() _turn_accumulator.set(acc) logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc)) @@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None: return _turn_accumulator.get() +@asynccontextmanager +async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]: + """Async context manager that scopes a fresh ``TurnTokenAccumulator`` + for the duration of the ``async with`` block, then *resets* the + ContextVar to its previous value on exit. + + This is the safe primitive for per-call billable operations + (image generation, vision LLM extraction, podcasts) that may run + inside an outer chat turn or be called sequentially from the same + background worker. Using ``ContextVar.set`` without ``reset`` (as + :func:`start_turn` does) would leak the inner accumulator into the + outer scope, causing the outer chat turn to debit cost twice. + + Usage:: + + async with scoped_turn() as acc: + await llm.ainvoke(...) + # acc.total_cost_micros captures cost from the LiteLLM callback + # Outer accumulator (if any) is restored here. 
+ """ + acc = TurnTokenAccumulator() + token = _turn_accumulator.set(acc) + logger.debug( + "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)", + id(acc), + token, + ) + try: + yield acc + finally: + _turn_accumulator.reset(token) + logger.debug( + "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)", + id(acc), + len(acc.calls), + acc.total_cost_micros, + ) + + +def _extract_cost_usd( + kwargs: dict[str, Any], + response_obj: Any, + model: str, + prompt_tokens: int, + completion_tokens: int, + is_image: bool = False, +) -> float: + """Best-effort USD cost extraction for a single LLM/image call. + + Tries four sources in priority order and returns the first that + yields a positive number; returns 0.0 if all four fail (the call + will then debit nothing from the user's balance — fail-safe). + + Sources: + 1. ``kwargs["response_cost"]`` — LiteLLM's standard callback + field, populated for ``Router.acompletion`` since PR #12500. + 2. ``response_obj._hidden_params["response_cost"]`` — same value + exposed on the response itself. + 3. ``litellm.completion_cost(completion_response=response_obj)`` + — recompute from the response and LiteLLM's pricing table. + 4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)`` + — manual fallback for OpenRouter/custom-Azure models that + only resolve via aliases registered by + ``pricing_registration`` at startup. **Skipped for image + responses** — ``cost_per_token`` does not support ``ImageResponse`` + and would raise; the cost map for image-gen lives in different + keys (``output_cost_per_image``) handled by ``completion_cost``. + """ + cost = kwargs.get("response_cost") + if cost is not None: + try: + value = float(cost) + except (TypeError, ValueError): + value = 0.0 + if value > 0: + return value + + hidden = getattr(response_obj, "_hidden_params", None) or {} + if isinstance(hidden, dict): + cost = hidden.get("response_cost") + if cost is not None: + try: + value = float(cost) + except (TypeError, ValueError): + value = 0.0 + if value > 0: + return value + + try: + value = float(litellm.completion_cost(completion_response=response_obj)) + if value > 0: + return value + except Exception as exc: + if is_image: + # Image-gen path: OpenRouter's image responses can omit + # ``usage.cost`` and LiteLLM's ``default_image_cost_calculator`` + # then *raises* (no cost map for OpenRouter image models). + # Bail out with a warning rather than falling through to + # cost_per_token (which is also incompatible with ImageResponse). + logger.warning( + "[TokenTracking] completion_cost failed for image model=%s " + "(provider may have omitted usage.cost). Debiting 0. " + "Cause: %s", + model, + exc, + ) + return 0.0 + logger.debug( + "[TokenTracking] completion_cost failed for model=%s: %s", model, exc + ) + + if is_image: + # Never call cost_per_token for ImageResponse — keys mismatch and + # the function is documented chat-only. 
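+        # By this point sources 1-3 (response_cost, _hidden_params,
+        # completion_cost) have produced no positive cost, so fail-safe:
+        # debit nothing rather than guess an image price from token counts.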
+ return 0.0 + + if model and (prompt_tokens > 0 or completion_tokens > 0): + try: + prompt_cost, completion_cost = litellm.cost_per_token( + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + value = float(prompt_cost) + float(completion_cost) + if value > 0: + return value + except Exception as exc: + logger.debug( + "[TokenTracking] cost_per_token failed for model=%s: %s", model, exc + ) + + return 0.0 + + class TokenTrackingCallback(CustomLogger): """LiteLLM callback that captures token usage into the turn accumulator.""" @@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger): ) return + # Detect image generation responses — they have a different usage + # shape (ImageUsage with input_tokens/output_tokens) and require a + # different cost-extraction path. We probe by class name to avoid a + # hard import dependency on litellm internals. + response_cls = type(response_obj).__name__ + is_image = response_cls == "ImageResponse" + usage = getattr(response_obj, "usage", None) if not usage: logger.debug( @@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger): ) return - prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 - completion_tokens = getattr(usage, "completion_tokens", 0) or 0 - total_tokens = getattr(usage, "total_tokens", 0) or 0 + if is_image: + # ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens`` + # (not prompt_tokens/completion_tokens). Several providers + # populate only one or neither (e.g. OpenRouter's gpt-image-1 + # passes through `input_tokens` from the prompt but no + # completion); fall through gracefully to 0. + prompt_tokens = getattr(usage, "input_tokens", 0) or 0 + completion_tokens = getattr(usage, "output_tokens", 0) or 0 + total_tokens = ( + getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens + ) + call_kind = "image_generation" + else: + prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(usage, "completion_tokens", 0) or 0 + total_tokens = getattr(usage, "total_tokens", 0) or 0 + call_kind = "chat" model = kwargs.get("model", "unknown") + cost_usd = _extract_cost_usd( + kwargs=kwargs, + response_obj=response_obj, + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + is_image=is_image, + ) + cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0 + + if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0): + logger.warning( + "[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d " + "kind=%s — debiting 0. 
Register pricing via pricing_registration or YAML " + "input_cost_per_token/output_cost_per_token (or rely on response_cost " + "for image generation).", + model, + prompt_tokens, + completion_tokens, + call_kind, + ) + acc.add( model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, + call_kind=call_kind, ) logger.info( - "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)", + "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d " + "cost=$%.6f (%d micros) (accumulator now has %d calls)", model, + call_kind, prompt_tokens, completion_tokens, total_tokens, + cost_usd, + cost_micros, len(acc.calls), ) @@ -168,6 +388,7 @@ async def record_token_usage( prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: int = 0, + cost_micros: int = 0, model_breakdown: dict[str, Any] | None = None, call_details: dict[str, Any] | None = None, thread_id: int | None = None, @@ -185,6 +406,7 @@ async def record_token_usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, model_breakdown=model_breakdown, call_details=call_details, thread_id=thread_id, @@ -194,11 +416,12 @@ async def record_token_usage( ) session.add(record) logger.debug( - "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d", + "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d", usage_type, prompt_tokens, completion_tokens, total_tokens, + cost_micros, ) return record except Exception: diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py index 0d782ab2b..ed5de921c 100644 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -3,6 +3,8 @@ from typing import Any from litellm import Router +from app.services.provider_api_base import resolve_api_base + logger = logging.getLogger(__name__) VISION_AUTO_MODE_ID = 0 @@ -108,10 +110,11 @@ class VisionLLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None + provider = config.get("provider", "").upper() if config.get("custom_provider"): - model_string = f"{config['custom_provider']}/{config['model_name']}" + provider_prefix = config["custom_provider"] + model_string = f"{provider_prefix}/{config['model_name']}" else: - provider = config.get("provider", "").upper() provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower()) model_string = f"{provider_prefix}/{config['model_name']}" @@ -120,8 +123,13 @@ class VisionLLMRouterService: "api_key": config.get("api_key"), } - if config.get("api_base"): - litellm_params["api_base"] = config["api_base"] + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) + if api_base: + litellm_params["api_base"] = api_base if config.get("api_version"): litellm_params["api_version"] = config["api_version"] diff --git a/surfsense_backend/app/tasks/celery_tasks/__init__.py b/surfsense_backend/app/tasks/celery_tasks/__init__.py index 5b1f2cd13..b23359f36 100644 --- a/surfsense_backend/app/tasks/celery_tasks/__init__.py +++ b/surfsense_backend/app/tasks/celery_tasks/__init__.py @@ -1,10 +1,25 @@ -"""Celery tasks package.""" +"""Celery tasks package. 
+ +Also hosts the small helpers every async celery task should use to +spin up its event loop. See :func:`run_async_celery_task` for the +canonical pattern. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +from collections.abc import Awaitable, Callable +from typing import TypeVar from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.pool import NullPool from app.config import config +logger = logging.getLogger(__name__) + _celery_engine = None _celery_session_maker = None @@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker: _celery_engine, expire_on_commit=False ) return _celery_session_maker + + +def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None: + """Drop the shared ``app.db.engine`` connection pool synchronously. + + The shared engine (used by ``shielded_async_session`` and most + routes / services) is a module-level singleton with a real pool. + Each celery task creates a fresh ``asyncio`` event loop; asyncpg + connections cache a reference to whichever loop opened them. When + a subsequent task's loop pulls a stale connection from the pool, + SQLAlchemy's ``pool_pre_ping`` checkout crashes with:: + + AttributeError: 'NoneType' object has no attribute 'send' + File ".../asyncio/proactor_events.py", line 402, in _loop_writing + self._write_fut = self._loop._proactor.send(self._sock, data) + + or hangs forever inside the asyncpg ``Connection._cancel`` cleanup + coroutine that can never run because its loop is gone. + + Disposing the engine forces the pool to drop every cached + connection so the next checkout opens a fresh one on the current + loop. Safe to call from a task's finally block; failure is logged + but never propagated. + """ + try: + from app.db import engine as shared_engine + + loop.run_until_complete(shared_engine.dispose()) + except Exception: + logger.warning("Shared DB engine dispose() failed", exc_info=True) + + +T = TypeVar("T") + + +def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T: + """Run an async coroutine inside a fresh event loop with proper + DB-engine cleanup. + + This is the canonical entry point for every async celery task. + It performs three responsibilities that were previously copy-pasted + (incorrectly) across each task module: + + 1. Create a fresh ``asyncio`` loop and install it on the current + thread (celery's ``--pool=solo`` runs every task on the main + thread, but other pool types don't). + 2. Dispose the shared ``app.db.engine`` BEFORE the task runs so + any stale connections left over from a previous task's loop + are dropped — defends against tasks that crashed without + cleaning up. + 3. Dispose the shared engine AFTER the task runs so the + connections we opened on this loop are released before the + loop closes (avoids ``coroutine 'Connection._cancel' was + never awaited`` warnings and the next-task hang). + + Use as:: + + @celery_app.task(name="my_task", bind=True) + def my_task(self, *args): + return run_async_celery_task(lambda: _my_task_impl(*args)) + """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + # Defense-in-depth: prior task may have crashed before + # disposing. Idempotent — no-op if pool is already empty. + _dispose_shared_db_engine(loop) + return loop.run_until_complete(coro_factory()) + finally: + # Drop any connections this task opened so they don't leak + # into the next task's loop. 
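+        # Must run on *this* loop, before loop.close(): asyncpg can only
+        # tear down a connection from the loop that opened it.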
+ _dispose_shared_db_engine(loop) + with contextlib.suppress(Exception): + loop.run_until_complete(loop.shutdown_asyncgens()) + with contextlib.suppress(Exception): + asyncio.set_event_loop(None) + loop.close() + + +__all__ = [ + "get_celery_session_maker", + "run_async_celery_task", +] diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py index fe1ac19d3..08d96cfa0 100644 --- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py @@ -4,7 +4,7 @@ import logging import traceback from app.celery_app import celery_app -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -49,22 +49,15 @@ def index_notion_pages_task( end_date: str, ): """Celery task to index Notion pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_notion_pages( + return run_async_celery_task( + lambda: _index_notion_pages( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_notion_pages", connector_id) raise - finally: - loop.close() async def _index_notion_pages( @@ -95,19 +88,11 @@ def index_github_repos_task( end_date: str, ): """Celery task to index GitHub repositories.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_github_repos( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_github_repos( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_github_repos( @@ -138,19 +123,11 @@ def index_confluence_pages_task( end_date: str, ): """Celery task to index Confluence pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_confluence_pages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_confluence_pages( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_confluence_pages( @@ -181,22 +158,15 @@ def index_google_calendar_events_task( end_date: str, ): """Celery task to index Google Calendar events.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_google_calendar_events( + return run_async_celery_task( + lambda: _index_google_calendar_events( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_google_calendar_events", connector_id) raise - finally: - loop.close() async def _index_google_calendar_events( @@ -227,19 +197,11 @@ def index_google_gmail_messages_task( end_date: str, ): """Celery task to index Google Gmail messages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_google_gmail_messages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_google_gmail_messages( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def 
_index_google_gmail_messages( @@ -269,22 +231,14 @@ def index_google_drive_files_task( items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options' ): """Celery task to index Google Drive folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_google_drive_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_google_drive_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_google_drive_files( @@ -317,22 +271,14 @@ def index_onedrive_files_task( items_dict: dict, ): """Celery task to index OneDrive folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_onedrive_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_onedrive_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_onedrive_files( @@ -365,22 +311,14 @@ def index_dropbox_files_task( items_dict: dict, ): """Celery task to index Dropbox folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_dropbox_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_dropbox_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_dropbox_files( @@ -414,19 +352,11 @@ def index_elasticsearch_documents_task( end_date: str, ): """Celery task to index Elasticsearch documents.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_elasticsearch_documents( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_elasticsearch_documents( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_elasticsearch_documents( @@ -457,22 +387,15 @@ def index_crawled_urls_task( end_date: str, ): """Celery task to index Web page Urls.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_crawled_urls( + return run_async_celery_task( + lambda: _index_crawled_urls( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_crawled_urls", connector_id) raise - finally: - loop.close() async def _index_crawled_urls( @@ -503,19 +426,11 @@ def index_bookstack_pages_task( end_date: str, ): """Celery task to index BookStack pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_bookstack_pages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_bookstack_pages( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_bookstack_pages( @@ -546,19 +461,11 @@ def index_composio_connector_task( end_date: str | None, ): """Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio).""" - import asyncio - - loop = asyncio.new_event_loop() - 
asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_composio_connector( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_composio_connector( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_composio_connector( diff --git a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py index c2dbe7700..5d6bde6c1 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py @@ -11,7 +11,7 @@ from app.db import Document from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str): document_id: ID of document to reindex user_id: ID of user who edited the document """ - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_reindex_document(document_id, user_id)) - finally: - loop.close() + return run_async_celery_task(lambda: _reindex_document(document_id, user_id)) async def _reindex_document(document_id: int, user_id: str): diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index 9d12f91f6..c78e376bd 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -11,7 +11,7 @@ from app.celery_app import celery_app from app.config import config from app.services.notification_service import NotificationService from app.services.task_logging_service import TaskLoggingService -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task from app.tasks.connector_indexers.local_folder_indexer import ( index_local_folder, index_uploaded_files, @@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int): ) def delete_document_task(self, document_id: int): """Celery task to delete a document and its chunks in batches.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_delete_document_background(document_id)) - finally: - loop.close() + return run_async_celery_task(lambda: _delete_document_background(document_id)) async def _delete_document_background(document_id: int) -> None: @@ -153,14 +148,9 @@ def delete_folder_documents_task( folder_subtree_ids: list[int] | None = None, ): """Celery task to delete documents first, then the folder rows.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _delete_folder_documents(document_ids, folder_subtree_ids) - ) - finally: - loop.close() + return run_async_celery_task( + lambda: _delete_folder_documents(document_ids, folder_subtree_ids) + ) async def _delete_folder_documents( @@ -209,12 +199,9 @@ async def _delete_folder_documents( ) def delete_search_space_task(self, search_space_id: int): """Celery task to delete a search 
space and heavy child rows in batches.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_delete_search_space_background(search_space_id)) - finally: - loop.close() + return run_async_celery_task( + lambda: _delete_search_space_background(search_space_id) + ) async def _delete_search_space_background(search_space_id: int) -> None: @@ -269,18 +256,11 @@ def process_extension_document_task( search_space_id: ID of the search space user_id: ID of the user """ - # Create a new event loop for this task - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _process_extension_document( - individual_document_dict, search_space_id, user_id - ) + return run_async_celery_task( + lambda: _process_extension_document( + individual_document_dict, search_space_id, user_id ) - finally: - loop.close() + ) async def _process_extension_document( @@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st search_space_id: ID of the search space user_id: ID of the user """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id)) - finally: - loop.close() + return run_async_celery_task( + lambda: _process_youtube_video(url, search_space_id, user_id) + ) async def _process_youtube_video(url: str, search_space_id: int, user_id: str): @@ -573,12 +549,9 @@ def process_file_upload_task( except Exception as e: logger.warning(f"[process_file_upload] Could not get file size: {e}") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _process_file_upload(file_path, filename, search_space_id, user_id) + run_async_celery_task( + lambda: _process_file_upload(file_path, filename, search_space_id, user_id) ) logger.info( f"[process_file_upload] Task completed successfully for: {filename}" @@ -589,8 +562,6 @@ def process_file_upload_task( f"Traceback:\n{traceback.format_exc()}" ) raise - finally: - loop.close() async def _process_file_upload( @@ -811,25 +782,17 @@ def process_file_upload_with_document_task( "File may have been removed before syncing could start." ) # Mark document as failed since file is missing - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _mark_document_failed( - document_id, - "File not found. Please re-upload the file.", - ) + run_async_celery_task( + lambda: _mark_document_failed( + document_id, + "File not found. 
Please re-upload the file.", ) - finally: - loop.close() + ) return - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _process_file_with_document( + run_async_celery_task( + lambda: _process_file_with_document( document_id, temp_path, filename, @@ -849,8 +812,6 @@ def process_file_upload_with_document_task( f"Traceback:\n{traceback.format_exc()}" ) raise - finally: - loop.close() async def _mark_document_failed(document_id: int, reason: str): @@ -1119,22 +1080,16 @@ def process_circleback_meeting_task( search_space_id: ID of the search space connector_id: ID of the Circleback connector (for deletion support) """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _process_circleback_meeting( - meeting_id, - meeting_name, - markdown_content, - metadata, - search_space_id, - connector_id, - ) + return run_async_celery_task( + lambda: _process_circleback_meeting( + meeting_id, + meeting_name, + markdown_content, + metadata, + search_space_id, + connector_id, ) - finally: - loop.close() + ) async def _process_circleback_meeting( @@ -1291,25 +1246,19 @@ def index_local_folder_task( target_file_paths: list[str] | None = None, ): """Celery task to index a local folder. Config is passed directly — no connector row.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_local_folder_async( - search_space_id=search_space_id, - user_id=user_id, - folder_path=folder_path, - folder_name=folder_name, - exclude_patterns=exclude_patterns, - file_extensions=file_extensions, - root_folder_id=root_folder_id, - enable_summary=enable_summary, - target_file_paths=target_file_paths, - ) + return run_async_celery_task( + lambda: _index_local_folder_async( + search_space_id=search_space_id, + user_id=user_id, + folder_path=folder_path, + folder_name=folder_name, + exclude_patterns=exclude_patterns, + file_extensions=file_extensions, + root_folder_id=root_folder_id, + enable_summary=enable_summary, + target_file_paths=target_file_paths, ) - finally: - loop.close() + ) async def _index_local_folder_async( @@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task( processing_mode: str = "basic", ): """Celery task to index files uploaded from the desktop app.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_uploaded_folder_files_async( - search_space_id=search_space_id, - user_id=user_id, - folder_name=folder_name, - root_folder_id=root_folder_id, - enable_summary=enable_summary, - file_mappings=file_mappings, - use_vision_llm=use_vision_llm, - processing_mode=processing_mode, - ) + return run_async_celery_task( + lambda: _index_uploaded_folder_files_async( + search_space_id=search_space_id, + user_id=user_id, + folder_name=folder_name, + root_folder_id=root_folder_id, + enable_summary=enable_summary, + file_mappings=file_mappings, + use_vision_llm=use_vision_llm, + processing_mode=processing_mode, ) - finally: - loop.close() + ) async def _index_uploaded_folder_files_async( @@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str: @celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1) def ai_sort_search_space_task(self, search_space_id: int, user_id: str): """Full AI sort for all documents in a search space.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id)) - 
finally: - loop.close() + return run_async_celery_task( + lambda: _ai_sort_search_space_async(search_space_id, user_id) + ) async def _ai_sort_search_space_async(search_space_id: int, user_id: str): @@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str): ) def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int): """Incremental AI sort for a single document after indexing.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _ai_sort_document_async(search_space_id, user_id, document_id) - ) - finally: - loop.close() + return run_async_celery_task( + lambda: _ai_sort_document_async(search_space_id, user_id, document_id) + ) async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int): diff --git a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py index 98b107af3..c6c8666f5 100644 --- a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py @@ -2,14 +2,13 @@ from __future__ import annotations -import asyncio import logging from app.celery_app import celery_app from app.db import SearchSourceConnector from app.schemas.obsidian_plugin import NotePayload from app.services.obsidian_plugin_indexer import upsert_note -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -22,18 +21,13 @@ def index_obsidian_attachment_task( user_id: str, ) -> None: """Process one Obsidian non-markdown attachment asynchronously.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_obsidian_attachment( - connector_id=connector_id, - payload_data=payload_data, - user_id=user_id, - ) + return run_async_celery_task( + lambda: _index_obsidian_attachment( + connector_id=connector_id, + payload_data=payload_data, + user_id=user_id, ) - finally: - loop.close() + ) async def _index_obsidian_attachment( diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py index 953011ecf..8b311576e 100644 --- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -3,14 +3,22 @@ import asyncio import logging import sys +from contextlib import asynccontextmanager from sqlalchemy import select from app.agents.podcaster.graph import graph as podcaster_graph from app.agents.podcaster.state import State as PodcasterState from app.celery_app import celery_app +from app.config import config as app_config from app.db import Podcast, PodcastStatus -from app.tasks.celery_tasks import get_celery_session_maker +from app.services.billable_calls import ( + BillingSettlementError, + QuotaInsufficientError, + _resolve_agent_billing_for_search_space, + billable_call, +) +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -28,6 +36,13 @@ if sys.platform.startswith("win"): # ============================================================================= +@asynccontextmanager +async def _celery_billable_session(): + """Session factory used by billable_call inside the Celery worker loop.""" + async with get_celery_session_maker()() as session: + yield session + + 
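+# Billing flow for the generation task below — a sketch of the contract
+# this module relies on (see app/services/billable_calls.py for the
+# authoritative implementation):
+#
+#     async with billable_call(...):             # reserve premium credit
+#         result = await podcaster_graph.ainvoke(...)
+#         # the LiteLLM callback records per-call cost_micros meanwhile
+#     # on exit: settle — debit actual cost, release the reservation
+#
+# QuotaInsufficientError  -> the reserve step was denied (credits exhausted)
+# BillingSettlementError  -> the debit after a successful run failed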
@celery_app.task(name="generate_content_podcast", bind=True) def generate_content_podcast_task( self, @@ -40,27 +55,22 @@ def generate_content_podcast_task( Celery task to generate podcast from source content. Updates existing podcast record created by the tool. """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - _generate_content_podcast( + return run_async_celery_task( + lambda: _generate_content_podcast( podcast_id, source_content, search_space_id, user_prompt, ) ) - loop.run_until_complete(loop.shutdown_asyncgens()) - return result except Exception as e: logger.error(f"Error generating content podcast: {e!s}") - loop.run_until_complete(_mark_podcast_failed(podcast_id)) + try: + run_async_celery_task(lambda: _mark_podcast_failed(podcast_id)) + except Exception: + logger.exception("Failed to mark podcast %s as failed", podcast_id) return {"status": "failed", "podcast_id": podcast_id} - finally: - asyncio.set_event_loop(None) - loop.close() async def _mark_podcast_failed(podcast_id: int) -> None: @@ -96,6 +106,31 @@ async def _generate_content_podcast( podcast.status = PodcastStatus.GENERATING await session.commit() + try: + ( + owner_user_id, + billing_tier, + base_model, + ) = await _resolve_agent_billing_for_search_space( + session, + search_space_id, + thread_id=podcast.thread_id, + ) + except ValueError as resolve_err: + logger.error( + "Podcast %s: cannot resolve billing for search_space=%s: %s", + podcast.id, + search_space_id, + resolve_err, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_resolution_failed", + } + graph_config = { "configurable": { "podcast_title": podcast.title, @@ -109,9 +144,52 @@ async def _generate_content_podcast( db_session=session, ) - graph_result = await podcaster_graph.ainvoke( - initial_state, config=graph_config - ) + try: + async with billable_call( + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS, + usage_type="podcast_generation", + call_details={ + "podcast_id": podcast.id, + "title": podcast.title, + "thread_id": podcast.thread_id, + }, + billable_session_factory=_celery_billable_session, + ): + graph_result = await podcaster_graph.ainvoke( + initial_state, config=graph_config + ) + except QuotaInsufficientError as exc: + logger.info( + "Podcast %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", + podcast.id, + exc.used_micros, + exc.limit_micros, + exc.remaining_micros, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "premium_quota_exhausted", + } + except BillingSettlementError: + logger.exception( + "Podcast %s: premium billing settlement failed", + podcast.id, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_settlement_failed", + } podcast_transcript = graph_result.get("podcast_transcript", []) file_path = graph_result.get("final_podcast_file_path", "") @@ -133,7 +211,14 @@ async def _generate_content_podcast( podcast.podcast_transcript = serializable_transcript podcast.file_location = file_path podcast.status = PodcastStatus.READY + logger.info( + "Podcast %s: committing READY transcript_entries=%d file=%s", + podcast.id, + 
len(serializable_transcript), + file_path, + ) await session.commit() + logger.info("Podcast %s: READY commit complete", podcast.id) logger.info(f"Successfully generated podcast: {podcast.id}") diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py index 373f04b48..e41251407 100644 --- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py @@ -7,7 +7,7 @@ from sqlalchemy.future import select from app.celery_app import celery_app from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task from app.utils.indexing_locks import is_connector_indexing_locked logger = logging.getLogger(__name__) @@ -20,15 +20,7 @@ def check_periodic_schedules_task(): This task runs every minute and triggers indexing for any connector whose next_scheduled_at time has passed. """ - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_check_and_trigger_schedules()) - finally: - loop.close() + return run_async_celery_task(_check_and_trigger_schedules) async def _check_and_trigger_schedules(): diff --git a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py index e05ae9435..d51c85dee 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py @@ -34,7 +34,7 @@ from sqlalchemy.future import select from app.celery_app import celery_app from app.config import config from app.db import Document, DocumentStatus, Notification -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task(): Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task. Also marks associated pending/processing documents as failed. 
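+
+    Both sweeps run sequentially on one fresh event loop via
+    ``run_async_celery_task`` (see the ``_both`` wrapper below).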
""" - import asyncio - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + async def _both() -> None: + await _cleanup_stale_notifications() + await _cleanup_stale_document_processing_notifications() - try: - loop.run_until_complete(_cleanup_stale_notifications()) - loop.run_until_complete(_cleanup_stale_document_processing_notifications()) - finally: - loop.close() + return run_async_celery_task(_both) async def _cleanup_stale_notifications(): diff --git a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py index 3aee1a360..ace6ef7ca 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import logging from datetime import UTC, datetime, timedelta @@ -18,7 +17,7 @@ from app.db import ( PremiumTokenPurchaseStatus, ) from app.routes import stripe_routes -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None: @celery_app.task(name="reconcile_pending_stripe_page_purchases") def reconcile_pending_stripe_page_purchases_task(): """Recover paid purchases that were left pending due to missed webhook handling.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_reconcile_pending_page_purchases()) - finally: - loop.close() + return run_async_celery_task(_reconcile_pending_page_purchases) async def _reconcile_pending_page_purchases() -> None: @@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None: @celery_app.task(name="reconcile_pending_stripe_token_purchases") def reconcile_pending_stripe_token_purchases_task(): """Recover paid token purchases that were left pending due to missed webhook handling.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_reconcile_pending_token_purchases()) - finally: - loop.close() + return run_async_celery_task(_reconcile_pending_token_purchases) async def _reconcile_pending_token_purchases() -> None: diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py index 7880b385f..08f22140c 100644 --- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py @@ -3,14 +3,22 @@ import asyncio import logging import sys +from contextlib import asynccontextmanager from sqlalchemy import select from app.agents.video_presentation.graph import graph as video_presentation_graph from app.agents.video_presentation.state import State as VideoPresentationState from app.celery_app import celery_app +from app.config import config as app_config from app.db import VideoPresentation, VideoPresentationStatus -from app.tasks.celery_tasks import get_celery_session_maker +from app.services.billable_calls import ( + BillingSettlementError, + QuotaInsufficientError, + _resolve_agent_billing_for_search_space, + billable_call, +) +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -23,6 +31,13 @@ if sys.platform.startswith("win"): ) +@asynccontextmanager +async def 
_celery_billable_session(): + """Session factory used by billable_call inside the Celery worker loop.""" + async with get_celery_session_maker()() as session: + yield session + + @celery_app.task(name="generate_video_presentation", bind=True) def generate_video_presentation_task( self, @@ -35,27 +50,30 @@ def generate_video_presentation_task( Celery task to generate video presentation from source content. Updates existing video presentation record created by the tool. """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - _generate_video_presentation( + return run_async_celery_task( + lambda: _generate_video_presentation( video_presentation_id, source_content, search_space_id, user_prompt, ) ) - loop.run_until_complete(loop.shutdown_asyncgens()) - return result except Exception as e: logger.error(f"Error generating video presentation: {e!s}") - loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id)) + # Mark FAILED in a fresh loop — the previous loop is closed. + # Swallow secondary failures; the row will simply stay in + # GENERATING and be flushed by the periodic stale cleanup. + try: + run_async_celery_task( + lambda: _mark_video_presentation_failed(video_presentation_id) + ) + except Exception: + logger.exception( + "Failed to mark video presentation %s as failed", + video_presentation_id, + ) return {"status": "failed", "video_presentation_id": video_presentation_id} - finally: - asyncio.set_event_loop(None) - loop.close() async def _mark_video_presentation_failed(video_presentation_id: int) -> None: @@ -97,6 +115,32 @@ async def _generate_video_presentation( video_pres.status = VideoPresentationStatus.GENERATING await session.commit() + try: + ( + owner_user_id, + billing_tier, + base_model, + ) = await _resolve_agent_billing_for_search_space( + session, + search_space_id, + thread_id=video_pres.thread_id, + ) + except ValueError as resolve_err: + logger.error( + "VideoPresentation %s: cannot resolve billing for " + "search_space=%s: %s", + video_pres.id, + search_space_id, + resolve_err, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": "billing_resolution_failed", + } + graph_config = { "configurable": { "video_title": video_pres.title, @@ -110,9 +154,52 @@ async def _generate_video_presentation( db_session=session, ) - graph_result = await video_presentation_graph.ainvoke( - initial_state, config=graph_config - ) + try: + async with billable_call( + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS, + usage_type="video_presentation_generation", + call_details={ + "video_presentation_id": video_pres.id, + "title": video_pres.title, + "thread_id": video_pres.thread_id, + }, + billable_session_factory=_celery_billable_session, + ): + graph_result = await video_presentation_graph.ainvoke( + initial_state, config=graph_config + ) + except QuotaInsufficientError as exc: + logger.info( + "VideoPresentation %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", + video_pres.id, + exc.used_micros, + exc.limit_micros, + exc.remaining_micros, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": 
"premium_quota_exhausted", + } + except BillingSettlementError: + logger.exception( + "VideoPresentation %s: premium billing settlement failed", + video_pres.id, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": "billing_settlement_failed", + } # Serialize slides (parsed content + audio info merged) slides_raw = graph_result.get("slides", []) @@ -143,7 +230,14 @@ async def _generate_video_presentation( video_pres.slides = serializable_slides video_pres.scene_codes = serializable_scene_codes video_pres.status = VideoPresentationStatus.READY + logger.info( + "VideoPresentation %s: committing READY slides=%d scene_codes=%d", + video_pres.id, + len(serializable_slides), + len(serializable_scene_codes), + ) await session.commit() + logger.info("VideoPresentation %s: READY commit complete", video_pres.id) logger.info(f"Successfully generated video presentation: {video_pres.id}") diff --git a/surfsense_backend/app/tasks/chat/content_builder.py b/surfsense_backend/app/tasks/chat/content_builder.py new file mode 100644 index 000000000..041cab286 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/content_builder.py @@ -0,0 +1,515 @@ +"""Server-side mirror of the frontend's assistant-ui ``ContentPart`` projection. + +Background +---------- +The streaming chat task in ``stream_new_chat`` / ``stream_resume_chat`` yields +SSE events that the frontend folds into a ``ContentPartsState`` (see +``surfsense_web/lib/chat/streaming-state.ts`` and the matching pipeline in +``stream-pipeline.ts``). When a turn ends, the frontend calls +``buildContentForPersistence(...)`` and round-trips that ``ContentPart[]`` +JSONB to ``POST /threads/{id}/messages``, which is what was historically +written to ``new_chat_messages.content``. + +After the ghost-thread fix moved persistence server-side, the assistant +row is written by ``finalize_assistant_turn`` in the streaming finally +block. The frontend's later ``appendMessage`` is now a no-op (recovers +via the ``(thread_id, turn_id, role)`` partial unique index added in +migration 141), which means the *server* is now responsible for +producing the rich ``ContentPart[]`` shape the FE expects on history +reload — text + reasoning + tool-call cards (with ``args``, ``argsText``, +``result``, ``langchainToolCallId``) + thinking-step buckets + +step-separators. + +This module is the in-memory accumulator that mirrors the FE state for +exactly that purpose. The streaming code calls ``on_text_*`` / ``on_reasoning_*`` +/ ``on_tool_*`` / ``on_thinking_step`` / ``on_step_separator`` / +``mark_interrupted`` at the same call sites it yields the matching +``streaming_service.format_*`` SSE event, so the in-memory ``parts`` list +stays in lockstep with what the FE's pipeline would have produced live. +``snapshot()`` is then taken once in the ``finally`` block and persisted +in a single UPDATE. + +Pure synchronous state — no DB I/O, no async, no flush callbacks. The +streaming code is responsible for driving lifecycle methods; this class +is a thin projection helper. +""" + +from __future__ import annotations + +import copy +import json +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +# Mirrors the FE's filter in ``buildContentForPersistence`` / ``buildContentForUI``: +# only text/reasoning/tool-call parts count as "meaningful". 
data-thinking-steps +# and data-step-separator decorate the meaningful parts but never stand alone +# in a successful turn. +_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"}) + + +class AssistantContentBuilder: + """Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``. + + Output shape (deep copy of ``self.parts`` via ``snapshot()``) strictly + matches the FE ``ContentPart`` union:: + + | { type: "text"; text: string } + | { type: "reasoning"; text: string } + | { type: "tool-call"; toolCallId: str; toolName: str; + args: dict; result?: any; argsText?: str; langchainToolCallId?: str; + state?: "aborted" } + | { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } } + | { type: "data-step-separator"; data: { stepIndex: int } } + + Order matches the wire order of the SSE events that drive the lifecycle + methods, with two FE-mirrored exceptions: + + 1. ``data-thinking-steps`` is a *singleton* and pinned at index 0 the + first time we see a ``data-thinking-step`` SSE event (the FE's + ``updateThinkingSteps`` does ``unshift`` on first sight). Subsequent + thinking-step updates mutate that singleton in place. + 2. ``data-step-separator`` is appended only when the message already has + meaningful content and the previous part isn't itself a separator + (so the FIRST step of a turn doesn't generate a leading divider). + """ + + def __init__(self) -> None: + self.parts: list[dict[str, Any]] = [] + # Index of the active text/reasoning part within ``parts`` while + # streaming is open; -1 means "no active part" and the next delta + # opens a fresh one. Mirrors ``ContentPartsState.currentTextPartIndex``. + self._current_text_idx: int = -1 + self._current_reasoning_idx: int = -1 + # ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the + # synthetic ``call_`` (legacy) or the LangChain + # ``tool_call.id`` (parity_v2) — same key the streaming layer + # threads through every ``tool-input-*`` / ``tool-output-*`` event. + self._tool_call_idx_by_ui_id: dict[str, int] = {} + # Live argsText accumulator (concatenated ``tool-input-delta`` chunks) + # so we can reproduce the FE's ``appendToolInputDelta`` behaviour + # before ``tool-input-available`` overwrites it with the + # pretty-printed final JSON. + self._args_text_by_ui_id: dict[str, str] = {} + + # ------------------------------------------------------------------ + # Text + # ------------------------------------------------------------------ + + def on_text_start(self, text_id: str) -> None: + """Begin a fresh text block. + + Symmetric to FE ``appendText``: opening text closes any active + reasoning so the renderer treats them as separate parts. The + actual text part isn't materialised here — it's lazily created + on the first ``on_text_delta`` so an empty start/end pair + leaves no trace. Matches the FE pipeline which has no explicit + ``text-start`` handler at all. + """ + if self._current_reasoning_idx >= 0: + self._current_reasoning_idx = -1 + + def on_text_delta(self, text_id: str, delta: str) -> None: + if not delta: + return + if self._current_reasoning_idx >= 0: + # FE behaviour: a text delta after reasoning implicitly + # closes the reasoning block (see ``appendText`` lines + # 178-180). 
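+            # Net effect: text and reasoning never interleave inside one
+            # part — a delta of the other kind always opens a fresh part.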
+ self._current_reasoning_idx = -1 + if ( + self._current_text_idx >= 0 + and 0 <= self._current_text_idx < len(self.parts) + and self.parts[self._current_text_idx].get("type") == "text" + ): + self.parts[self._current_text_idx]["text"] += delta + return + self.parts.append({"type": "text", "text": delta}) + self._current_text_idx = len(self.parts) - 1 + + def on_text_end(self, text_id: str) -> None: + """Close the active text block. + + Mirrors the wire-level ``text-end`` boundary the streaming layer + emits before tool calls / reasoning / step boundaries. The FE + pipeline implicitly closes via ``currentTextPartIndex = -1`` + in ``addToolCall`` / ``appendReasoning`` / ``addStepSeparator``; + our helper does the same explicitly so callers don't have to + maintain that invariant per call site. + """ + self._current_text_idx = -1 + + # ------------------------------------------------------------------ + # Reasoning + # ------------------------------------------------------------------ + + def on_reasoning_start(self, reasoning_id: str) -> None: + if self._current_text_idx >= 0: + self._current_text_idx = -1 + + def on_reasoning_delta(self, reasoning_id: str, delta: str) -> None: + if not delta: + return + if self._current_text_idx >= 0: + self._current_text_idx = -1 + if ( + self._current_reasoning_idx >= 0 + and 0 <= self._current_reasoning_idx < len(self.parts) + and self.parts[self._current_reasoning_idx].get("type") == "reasoning" + ): + self.parts[self._current_reasoning_idx]["text"] += delta + return + self.parts.append({"type": "reasoning", "text": delta}) + self._current_reasoning_idx = len(self.parts) - 1 + + def on_reasoning_end(self, reasoning_id: str) -> None: + self._current_reasoning_idx = -1 + + # ------------------------------------------------------------------ + # Tool calls + # ------------------------------------------------------------------ + + def on_tool_input_start( + self, + ui_id: str, + tool_name: str, + langchain_tool_call_id: str | None, + ) -> None: + """Register a tool-call card. Args are filled in by later events.""" + if not ui_id: + return + # Skip duplicate registration: parity_v2 may emit + # ``tool-input-start`` from both ``on_chat_model_stream`` + # (when tool_call_chunks register a name) and ``on_tool_start`` + # (the canonical path). The FE de-dupes via ``toolCallIndices``; + # we mirror that here. + if ui_id in self._tool_call_idx_by_ui_id: + if langchain_tool_call_id: + idx = self._tool_call_idx_by_ui_id[ui_id] + part = self.parts[idx] + if not part.get("langchainToolCallId"): + part["langchainToolCallId"] = langchain_tool_call_id + return + + part: dict[str, Any] = { + "type": "tool-call", + "toolCallId": ui_id, + "toolName": tool_name, + "args": {}, + } + if langchain_tool_call_id: + part["langchainToolCallId"] = langchain_tool_call_id + self.parts.append(part) + self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1 + + self._current_text_idx = -1 + self._current_reasoning_idx = -1 + + def on_tool_input_delta(self, ui_id: str, args_chunk: str) -> None: + """Append a streamed args-delta chunk to the matching card's argsText. + + Mirrors FE ``appendToolInputDelta``: no-ops when no card has been + registered yet for the given ``ui_id`` — the deltas have nowhere + safe to land. 
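+
+        E.g. streamed chunks ``'{"qu'``, ``'ery": "ca'``, ``'ts"}'``
+        accumulate to ``argsText == '{"query": "cats"}'`` until
+        ``on_tool_input_available`` overwrites it with the
+        pretty-printed final JSON.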
+ """ + if not ui_id or not args_chunk: + return + idx = self._tool_call_idx_by_ui_id.get(ui_id) + if idx is None: + return + if not (0 <= idx < len(self.parts)): + return + part = self.parts[idx] + if part.get("type") != "tool-call": + return + new_text = (part.get("argsText") or "") + args_chunk + part["argsText"] = new_text + self._args_text_by_ui_id[ui_id] = new_text + + def on_tool_input_available( + self, + ui_id: str, + tool_name: str, + args: dict[str, Any], + langchain_tool_call_id: str | None, + ) -> None: + """Finalize the tool-call card's input. + + Mirrors FE ``stream-pipeline.ts`` lines 127-153: replaces ``argsText`` + with ``json.dumps(input, indent=2)`` so the post-stream card renders + pretty-printed JSON, sets the full ``args`` dict, and backfills + ``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time. + Also creates the card if no prior ``tool-input-start`` registered it + (legacy parity_v2-OFF / late-registration paths). + """ + if not ui_id: + return + try: + final_args_text = json.dumps(args or {}, indent=2, ensure_ascii=False) + except (TypeError, ValueError): + # Defensive: ``args`` should already be JSON-safe (the + # streaming layer sanitizes it before emitting), but if a + # caller hands us a non-serializable value we still want + # to record the call without breaking the snapshot. + final_args_text = str(args) + + idx = self._tool_call_idx_by_ui_id.get(ui_id) + if idx is not None and 0 <= idx < len(self.parts): + part = self.parts[idx] + if part.get("type") == "tool-call": + part["args"] = args or {} + part["argsText"] = final_args_text + if langchain_tool_call_id and not part.get("langchainToolCallId"): + part["langchainToolCallId"] = langchain_tool_call_id + return + + # No prior tool-input-start: register the card now. + new_part: dict[str, Any] = { + "type": "tool-call", + "toolCallId": ui_id, + "toolName": tool_name, + "args": args or {}, + "argsText": final_args_text, + } + if langchain_tool_call_id: + new_part["langchainToolCallId"] = langchain_tool_call_id + self.parts.append(new_part) + self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1 + + self._current_text_idx = -1 + self._current_reasoning_idx = -1 + + def on_tool_output_available( + self, + ui_id: str, + output: Any, + langchain_tool_call_id: str | None, + ) -> None: + """Attach the tool's output (``result``) to the matching card. + + Mirrors FE ``updateToolCall``: backfill ``langchainToolCallId`` + only if not already set (a NULL late-arriving value never blows + away an earlier known good one). + """ + if not ui_id: + return + idx = self._tool_call_idx_by_ui_id.get(ui_id) + if idx is None or not (0 <= idx < len(self.parts)): + return + part = self.parts[idx] + if part.get("type") != "tool-call": + return + part["result"] = output + if langchain_tool_call_id and not part.get("langchainToolCallId"): + part["langchainToolCallId"] = langchain_tool_call_id + + # ------------------------------------------------------------------ + # Thinking steps & step separators + # ------------------------------------------------------------------ + + def on_thinking_step( + self, + step_id: str, + title: str, + status: str, + items: list[str] | None, + ) -> None: + """Update / insert the singleton ``data-thinking-steps`` part. + + Mirrors FE ``updateThinkingSteps``: maintain a single + ``data-thinking-steps`` part anchored at index 0, replacing or + unshifting on first sight. Each ``on_thinking_step`` call + replaces the entry in the steps list keyed by ``step_id`` (or + appends if new). 
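+
+        Sketch of the resulting ``parts`` evolution (ids illustrative)::
+
+            builder.on_text_delta("t1", "hi")
+            builder.on_thinking_step("s1", "Reading file", "in_progress", [])
+            # parts == [<data-thinking-steps>, <text>] -- unshifted to index 0
+            builder.on_thinking_step("s1", "Reading file", "completed", [])
+            # still two parts: step "s1" was replaced in place, not appended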
+ """ + if not step_id: + return + + new_step = { + "id": step_id, + "title": title or "", + "status": status or "in_progress", + "items": list(items) if items else [], + } + + # Find existing data-thinking-steps part. + existing_idx = -1 + for i, p in enumerate(self.parts): + if p.get("type") == "data-thinking-steps": + existing_idx = i + break + + if existing_idx >= 0: + current_steps = self.parts[existing_idx].get("data", {}).get("steps") or [] + replaced = False + for i, step in enumerate(current_steps): + if step.get("id") == step_id: + current_steps[i] = new_step + replaced = True + break + if not replaced: + current_steps.append(new_step) + self.parts[existing_idx] = { + "type": "data-thinking-steps", + "data": {"steps": current_steps}, + } + return + + # First sight: unshift to position 0 (FE parity). + self.parts.insert( + 0, + { + "type": "data-thinking-steps", + "data": {"steps": [new_step]}, + }, + ) + # Bump tracked indices since we inserted at the head. + if self._current_text_idx >= 0: + self._current_text_idx += 1 + if self._current_reasoning_idx >= 0: + self._current_reasoning_idx += 1 + for ui_id, idx in list(self._tool_call_idx_by_ui_id.items()): + self._tool_call_idx_by_ui_id[ui_id] = idx + 1 + + def on_step_separator(self) -> None: + """Append a ``data-step-separator`` between consecutive model steps. + + Mirrors FE ``addStepSeparator``: only emit when the message + already has meaningful content AND the previous part isn't + itself a separator. ``stepIndex`` is the running count of + separators already in ``parts``. + """ + has_content = any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts) + if not has_content: + return + if self.parts and self.parts[-1].get("type") == "data-step-separator": + return + step_index = sum( + 1 for p in self.parts if p.get("type") == "data-step-separator" + ) + self.parts.append( + { + "type": "data-step-separator", + "data": {"stepIndex": step_index}, + } + ) + self._current_text_idx = -1 + self._current_reasoning_idx = -1 + + # ------------------------------------------------------------------ + # Interruption handling + # ------------------------------------------------------------------ + + def mark_interrupted(self) -> None: + """Close any open text/reasoning and flip running tools to aborted. + + Called from the streaming ``finally`` block before ``snapshot()`` so + the persisted JSONB reflects a coherent end-state even when the + client disconnected mid-turn or the agent hit a fatal error. + + - Active text/reasoning blocks: simply lose their "active" + marker (no synthetic content appended). Whatever was streamed + stays as-is. + - Tool-call parts that never received a ``result`` get + ``state="aborted"`` so the FE history loader can render them + as "interrupted" rather than "still running". + """ + self._current_text_idx = -1 + self._current_reasoning_idx = -1 + for part in self.parts: + if part.get("type") != "tool-call": + continue + if "result" in part: + continue + part["state"] = "aborted" + + # ------------------------------------------------------------------ + # Snapshot & introspection + # ------------------------------------------------------------------ + + def snapshot(self) -> list[dict[str, Any]]: + """Return a deep copy of ``parts`` ready for SQL UPDATE / json.dumps. 
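+
+        For example (illustrative; assumes only a text part was streamed)::
+
+            snap = builder.snapshot()
+            snap[0]["text"] += " mutated"
+            # builder.parts is unaffected -- the snapshot was deep-copied.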
+ + Deep-copied so callers that finalize from the shielded ``finally`` + block can't accidentally mutate the persisted payload while the + SQL UPDATE is in flight (the streaming layer doesn't touch the + builder after this call, but defensive copies are cheap and cheap + is what we want in a finally block). + """ + return copy.deepcopy(self.parts) + + def is_empty(self) -> bool: + """True if no meaningful content was captured. + + ``data-thinking-steps`` and ``data-step-separator`` decorate + meaningful content but don't count on their own — a turn that + only emitted a thinking step before being interrupted should + still be treated as empty for the status-marker fallback. + """ + return not any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts) + + def stats(self) -> dict[str, int]: + """Return counts of each part-type plus rough byte size. + + Used by the streaming layer's perf logger so an ops dashboard + can correlate finalize latency with payload size, and so a + regression that quietly stops emitting tool-call parts (or + starts emitting hundreds) shows up in [PERF] grep rather than + only as a "history reload looks weird" bug report. + + ``bytes`` is the JSON-serialised payload length — what actually + crosses the wire to PostgreSQL's JSONB column. We compute it + with ``ensure_ascii=False`` to match the JSONB encoder's UTF-8 + on-disk layout closely enough for back-of-the-envelope sizing. + Reasoning/text/tool-call/thinking-step/step-separator counts are + independent so any one can spike without the others. + + Defensive: ``json.dumps`` failure (a non-serializable value + slipped past the streaming layer's sanitization) is reported as + ``bytes=-1`` rather than raised — perf logging must not be the + thing that breaks the streaming finally block. + """ + text_blocks = 0 + reasoning_blocks = 0 + tool_calls = 0 + tool_calls_completed = 0 + tool_calls_aborted = 0 + thinking_step_parts = 0 + step_separators = 0 + + for part in self.parts: + kind = part.get("type") + if kind == "text": + text_blocks += 1 + elif kind == "reasoning": + reasoning_blocks += 1 + elif kind == "tool-call": + tool_calls += 1 + if part.get("state") == "aborted": + tool_calls_aborted += 1 + elif "result" in part: + tool_calls_completed += 1 + elif kind == "data-thinking-steps": + thinking_step_parts += 1 + elif kind == "data-step-separator": + step_separators += 1 + + try: + byte_size = len(json.dumps(self.parts, ensure_ascii=False, default=str)) + except (TypeError, ValueError): + byte_size = -1 + + return { + "parts": len(self.parts), + "bytes": byte_size, + "text": text_blocks, + "reasoning": reasoning_blocks, + "tool_calls": tool_calls, + "tool_calls_completed": tool_calls_completed, + "tool_calls_aborted": tool_calls_aborted, + "thinking_step_parts": thinking_step_parts, + "step_separators": step_separators, + } diff --git a/surfsense_backend/app/tasks/chat/persistence.py b/surfsense_backend/app/tasks/chat/persistence.py new file mode 100644 index 000000000..b2b8b6a88 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/persistence.py @@ -0,0 +1,534 @@ +"""Server-side message persistence helpers for the streaming chat agent. + +Historically the streaming task (``stream_new_chat``/``stream_resume_chat``) +left ``new_chat_messages`` empty and relied on the frontend to round-trip +``POST /threads/{id}/messages`` afterwards. That gave authenticated clients +a "ghost-thread" abuse vector: skip the round-trip and burn LLM tokens +without leaving an audit trail. 
These helpers move both writes (the user +turn that triggered the stream and the assistant turn the stream produced) +into the server itself, idempotent against the partial unique index +``uq_new_chat_messages_thread_turn_role`` so legacy frontends that *do* +keep posting via ``appendMessage`` simply hit the unique-index recovery +path on the second writer instead of creating duplicates. + +Assistant turn lifecycle +------------------------ +The assistant side is split into two helpers so we can capture the row id +*before* the stream produces any output: + +* ``persist_assistant_shell`` runs immediately after ``persist_user_turn`` + and INSERTs an empty assistant row anchored to ``(thread_id, turn_id, + ASSISTANT)``. Returns the row id so the streaming layer can correlate + later writes (token_usage, AgentActionLog future-correlation) against + a stable PK from the start of the turn. +* ``finalize_assistant_turn`` runs from the streaming ``finally`` block. + It UPDATEs the row's ``content`` to the rich ``ContentPart[]`` snapshot + produced server-side by ``AssistantContentBuilder`` and writes the + ``token_usage`` row using ``INSERT ... ON CONFLICT DO NOTHING`` against + the ``uq_token_usage_message_id`` partial unique index from migration + 142, hard-eliminating any race against ``append_message``'s recovery + branch. + +Defensive contract +------------------ + +* Every helper runs inside ``shielded_async_session()`` so ``session.close()`` + survives starlette's mid-stream cancel scope on client disconnect. +* ``persist_user_turn`` and ``persist_assistant_shell`` use ``INSERT ... ON + CONFLICT DO NOTHING ... RETURNING id`` keyed on the ``(thread_id, turn_id, + role)`` partial unique index. On conflict the insert silently no-ops at + the DB level — no Python ``IntegrityError`` is constructed, which + eliminates spurious debugger pauses and keeps logs clean. On conflict a + follow-up ``SELECT`` resolves the existing row id so the streaming layer + can correlate writes against a stable PK. +* ``finalize_assistant_turn`` is best-effort: it never raises. The + streaming ``finally`` block calls it from within + ``anyio.CancelScope(shield=True)`` and any raised exception there + would mask the real error. +""" + +from __future__ import annotations + +import logging +import time +from datetime import UTC, datetime +from typing import Any +from uuid import UUID + +from sqlalchemy import text as sa_text +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.future import select + +from app.db import ( + NewChatMessage, + NewChatMessageRole, + NewChatThread, + TokenUsage, + shielded_async_session, +) +from app.services.token_tracking_service import ( + TurnTokenAccumulator, +) +from app.utils.perf import get_perf_logger + +logger = logging.getLogger(__name__) +_perf_log = get_perf_logger() + + +# Empty initial assistant content. ``finalize_assistant_turn`` overwrites +# this in a single UPDATE at end-of-stream with the full ``ContentPart[]`` +# snapshot produced by ``AssistantContentBuilder``. We persist a one-element +# list with an empty text part so a crash between shell-INSERT and finalize +# leaves the row in a FE-renderable shape (blank bubble) instead of +# blowing up the history loader. +_EMPTY_SHELL_CONTENT: list[dict[str, Any]] = [{"type": "text", "text": ""}] + +# Substituted content for genuinely empty turns (no text, no reasoning, +# no tool calls). 
The streaming layer flips to this when +# ``AssistantContentBuilder.is_empty()`` returns True so the persisted +# row is at least somewhat self-describing instead of an empty text +# bubble. The FE's ``ContentPart`` union doesn't include ``status`` +# yet, so the history loader will silently drop this part and render +# a blank bubble (matches today's behaviour for empty turns); a follow-up +# FE PR adds the explicit "no response" rendering. +_STATUS_NO_RESPONSE: list[dict[str, Any]] = [ + {"type": "status", "text": "(no text response)"} +] + + +def _build_user_content( + user_query: str, + user_image_data_urls: list[str] | None, + mentioned_documents: list[dict[str, Any]] | None = None, +) -> list[dict[str, Any]]: + """Build the persisted user-message ``content`` (assistant-ui v2 parts). + + Mirrors the shape the existing frontend posts via + ``appendMessage`` (see ``surfsense_web/.../new-chat/[[...chat_id]]/page.tsx``): + + [{"type": "text", "text": "..."}, + {"type": "image", "image": "data:..."}, + {"type": "mentioned-documents", "documents": [{"id": int, + "title": str, "document_type": str}, ...]}] + + The companion reader is + ``app.utils.user_message_multimodal.split_persisted_user_content_parts`` + which expects exactly this shape — keep them in sync. + + ``mentioned_documents``: optional list of ``{id, title, document_type}`` + dicts. When non-empty (and a ``mentioned-documents`` part is not already + in some other input shape), a single ``{"type": "mentioned-documents", + "documents": [...]}`` part is appended. Mirrors the FE injection at + ``page.tsx:281-286`` (``persistUserTurn``). + """ + parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}] + for url in user_image_data_urls or (): + if isinstance(url, str) and url: + parts.append({"type": "image", "image": url}) + if mentioned_documents: + normalized: list[dict[str, Any]] = [] + for doc in mentioned_documents: + if not isinstance(doc, dict): + continue + doc_id = doc.get("id") + title = doc.get("title") + document_type = doc.get("document_type") + if doc_id is None or title is None or document_type is None: + continue + normalized.append( + { + "id": doc_id, + "title": str(title), + "document_type": str(document_type), + } + ) + if normalized: + parts.append({"type": "mentioned-documents", "documents": normalized}) + return parts + + +async def persist_user_turn( + *, + chat_id: int, + user_id: str | None, + turn_id: str, + user_query: str, + user_image_data_urls: list[str] | None = None, + mentioned_documents: list[dict[str, Any]] | None = None, +) -> int | None: + """Persist the user-side row for a chat turn and return its ``id``. + + Uses ``INSERT ... ON CONFLICT DO NOTHING ... RETURNING id`` keyed on the + ``(thread_id, turn_id, role)`` partial unique index from migration 141 + (``WHERE turn_id IS NOT NULL``). On conflict the insert silently no-ops + at the DB level — no Python ``IntegrityError`` is constructed, which + eliminates the debugger pause that ``justMyCode=false`` + async greenlet + interactions used to produce, and keeps production logs clean. + + Returns the ``id`` of the row that exists for this turn after the call: + the freshly inserted ``id`` on the happy path, or the existing ``id`` + when a previous writer (legacy FE ``appendMessage`` racing the SSE + stream, redelivered request, etc.) already wrote it. 
Returns ``None`` + only on genuine DB failure; the caller should yield a streaming error + and abort the turn so we never produce a title/assistant row that + isn't anchored to a persisted user message. + + Other constraint violations (FK, NOT NULL, etc.) still raise + ``IntegrityError`` — only the ``(thread_id, turn_id, role)`` collision + is silenced. + """ + if not turn_id: + # Defensive: turn_id is always populated by the streaming path + # before this helper is called. If it isn't, we cannot be + # idempotent against the unique index — refuse to write rather + # than create a row the unique index can't dedupe. + logger.error( + "persist_user_turn called without a turn_id (chat_id=%s); skipping", + chat_id, + ) + return None + + t0 = time.perf_counter() + outcome = "failed" + resolved_id: int | None = None + try: + async with shielded_async_session() as ws: + # Re-attach the thread row so we can also bump updated_at + # in the same write — keeps the sidebar ordering accurate + # when a user fires off a turn but never reaches the + # legacy appendMessage. + thread = await ws.get(NewChatThread, chat_id) + author_uuid: UUID | None = None + if user_id: + try: + author_uuid = UUID(user_id) + except (TypeError, ValueError): + logger.warning( + "persist_user_turn: invalid user_id=%r, persisting as anonymous", + user_id, + ) + + content_payload = _build_user_content( + user_query, user_image_data_urls, mentioned_documents + ) + insert_stmt = ( + pg_insert(NewChatMessage) + .values( + thread_id=chat_id, + role=NewChatMessageRole.USER, + content=content_payload, + author_id=author_uuid, + turn_id=turn_id, + ) + .on_conflict_do_nothing( + index_elements=["thread_id", "turn_id", "role"], + index_where=sa_text("turn_id IS NOT NULL"), + ) + .returning(NewChatMessage.id) + ) + inserted_id = (await ws.execute(insert_stmt)).scalar() + + if inserted_id is None: + # Conflict on partial unique index — another writer + # (legacy FE appendMessage, redelivered request, etc.) + # already persisted this row. Look it up and reuse. + lookup = await ws.execute( + select(NewChatMessage.id).where( + NewChatMessage.thread_id == chat_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.USER, + ) + ) + existing_id = lookup.scalars().first() + if existing_id is None: + # Conflict reported but no row found — extremely + # unlikely (concurrent DELETE). Surface as failure. + logger.warning( + "persist_user_turn: conflict but no matching row " + "(chat_id=%s, turn_id=%s)", + chat_id, + turn_id, + ) + outcome = "integrity_no_match" + return None + resolved_id = int(existing_id) + outcome = "race_recovered" + else: + resolved_id = int(inserted_id) + outcome = "inserted" + # Bump thread.updated_at only on a real insert — when + # we recovered an existing row the prior writer + # already touched the thread. 
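+                    # Illustrative SQL equivalent (table name assumed from
+                    # the ``NewChatThread`` model; the real write goes
+                    # through the ORM row below):
+                    #   UPDATE new_chat_threads SET updated_at = now()
+                    #   WHERE id = :chat_id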
+ if thread is not None: + thread.updated_at = datetime.now(UTC) + + await ws.commit() + return resolved_id + except Exception: + logger.exception( + "persist_user_turn failed (chat_id=%s, turn_id=%s)", + chat_id, + turn_id, + ) + return None + finally: + _perf_log.info( + "[persist_user_turn] outcome=%s chat_id=%s turn_id=%s " + "message_id=%s query_len=%d images=%d mentioned_docs=%d " + "in %.3fs", + outcome, + chat_id, + turn_id, + resolved_id, + len(user_query or ""), + len(user_image_data_urls or ()), + len(mentioned_documents or ()), + time.perf_counter() - t0, + ) + + +async def persist_assistant_shell( + *, + chat_id: int, + user_id: str | None, + turn_id: str, +) -> int | None: + """Pre-write an empty assistant row for the turn and return its id. + + Inserts a placeholder ``new_chat_messages`` row (empty text content) so + the streaming layer has a stable ``message_id`` to correlate against + for the rest of the turn. ``finalize_assistant_turn`` overwrites the + ``content`` field at end-of-stream with the rich ``ContentPart[]`` + snapshot produced by ``AssistantContentBuilder``. + + Returns the row id on success, ``None`` on a genuine DB failure (caller + should abort the turn rather than stream into a void). + + Idempotent against the ``(thread_id, turn_id, ASSISTANT)`` partial unique + index from migration 141: if a row already exists (resume retry, racing + legacy frontend, redelivered request, etc.) we look it up by + ``(thread_id, turn_id, role)`` and return its existing id. The streaming + layer is then free to UPDATE that row at finalize time. + """ + if not turn_id: + logger.error( + "persist_assistant_shell called without a turn_id (chat_id=%s); skipping", + chat_id, + ) + return None + + t0 = time.perf_counter() + outcome = "failed" + resolved_id: int | None = None + try: + async with shielded_async_session() as ws: + insert_stmt = ( + pg_insert(NewChatMessage) + .values( + thread_id=chat_id, + role=NewChatMessageRole.ASSISTANT, + content=_EMPTY_SHELL_CONTENT, + author_id=None, + turn_id=turn_id, + ) + .on_conflict_do_nothing( + index_elements=["thread_id", "turn_id", "role"], + index_where=sa_text("turn_id IS NOT NULL"), + ) + .returning(NewChatMessage.id) + ) + inserted_id = (await ws.execute(insert_stmt)).scalar() + + if inserted_id is None: + # Conflict — another writer (legacy FE appendMessage, + # resume retry, redelivered request) wrote the + # (thread_id, turn_id, ASSISTANT) row first. Look it up + # so the streaming layer can UPDATE the same row at + # finalize time. 
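+                # Illustrative SQL equivalent of the lookup built below
+                # (table name assumed from the ``NewChatMessage`` model):
+                #   SELECT id FROM new_chat_messages
+                #   WHERE thread_id = :chat_id
+                #     AND turn_id = :turn_id
+                #     AND role = 'ASSISTANT'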
+ lookup = await ws.execute( + select(NewChatMessage.id).where( + NewChatMessage.thread_id == chat_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.ASSISTANT, + ) + ) + existing_id = lookup.scalars().first() + if existing_id is None: + logger.warning( + "persist_assistant_shell: conflict but no matching " + "(thread_id, turn_id, role) row found " + "(chat_id=%s, turn_id=%s)", + chat_id, + turn_id, + ) + outcome = "integrity_no_match" + return None + resolved_id = int(existing_id) + outcome = "race_recovered" + else: + resolved_id = int(inserted_id) + outcome = "inserted" + + await ws.commit() + return resolved_id + except Exception: + logger.exception( + "persist_assistant_shell failed (chat_id=%s, turn_id=%s)", + chat_id, + turn_id, + ) + return None + finally: + _perf_log.info( + "[persist_assistant_shell] outcome=%s chat_id=%s turn_id=%s " + "message_id=%s in %.3fs", + outcome, + chat_id, + turn_id, + resolved_id, + time.perf_counter() - t0, + ) + + +async def finalize_assistant_turn( + *, + message_id: int, + chat_id: int, + search_space_id: int, + user_id: str | None, + turn_id: str, + content: list[dict[str, Any]], + accumulator: TurnTokenAccumulator | None, +) -> None: + """Finalize the assistant row and write its token_usage. + + Two writes in a single shielded session: + + 1. ``UPDATE new_chat_messages SET content = :c, updated_at = now() + WHERE id = :id`` — overwrites the placeholder ``persist_assistant_shell`` + wrote with the full ``ContentPart[]`` snapshot produced server-side. + 2. ``INSERT INTO token_usage (...) VALUES (...) ON CONFLICT (message_id) + WHERE message_id IS NOT NULL DO NOTHING`` — uses the partial unique + index ``uq_token_usage_message_id`` from migration 142 to make the + insert idempotent against ``append_message``'s recovery branch + (which uses the same ON CONFLICT clause). + + Substitutes the status-marker payload when ``content`` is empty + (pure tool-call turn that aborted before any output, or interrupt + before any event arrived). The status marker is preferable to a + blank text bubble because token accounting still runs and an ops + dashboard can flag the row. + + Best-effort — never raises. The streaming ``finally`` calls this + from within ``anyio.CancelScope(shield=True)``; any raised exception + here would mask the real error that triggered the cleanup. + """ + if not turn_id: + logger.error( + "finalize_assistant_turn called without turn_id " + "(chat_id=%s, message_id=%s); skipping", + chat_id, + message_id, + ) + return + if not message_id: + logger.error( + "finalize_assistant_turn called without message_id " + "(chat_id=%s, turn_id=%s); skipping", + chat_id, + turn_id, + ) + return + + payload: list[dict[str, Any]] + is_status_marker = False + if content: + payload = content + else: + payload = _STATUS_NO_RESPONSE + is_status_marker = True + + t0 = time.perf_counter() + outcome = "failed" + token_usage_attempted = bool( + accumulator is not None and accumulator.calls and user_id + ) + try: + async with shielded_async_session() as ws: + assistant_row = await ws.get(NewChatMessage, message_id) + if assistant_row is None: + logger.warning( + "finalize_assistant_turn: row not found " + "(chat_id=%s, message_id=%s, turn_id=%s); skipping", + chat_id, + message_id, + turn_id, + ) + outcome = "row_missing" + return + + assistant_row.content = payload + assistant_row.updated_at = datetime.now(UTC) + + # Token usage. 
``record_token_usage`` (used elsewhere) does + # SELECT-then-INSERT in two statements which races with + # ``append_message``. Switch to a single INSERT ... ON + # CONFLICT DO NOTHING keyed on the migration-142 partial + # unique index so the loser silently drops its write at + # the DB level — exactly one row per ``message_id``, + # regardless of which session committed first. + if accumulator is not None and accumulator.calls and user_id: + try: + user_uuid = UUID(user_id) + except (TypeError, ValueError): + logger.warning( + "finalize_assistant_turn: invalid user_id=%r, " + "skipping token_usage row", + user_id, + ) + else: + insert_stmt = ( + pg_insert(TokenUsage) + .values( + usage_type="chat", + prompt_tokens=accumulator.total_prompt_tokens, + completion_tokens=accumulator.total_completion_tokens, + total_tokens=accumulator.grand_total, + cost_micros=accumulator.total_cost_micros, + model_breakdown=accumulator.per_message_summary(), + call_details={"calls": accumulator.serialized_calls()}, + thread_id=chat_id, + message_id=message_id, + search_space_id=search_space_id, + user_id=user_uuid, + ) + .on_conflict_do_nothing( + index_elements=["message_id"], + index_where=sa_text("message_id IS NOT NULL"), + ) + ) + await ws.execute(insert_stmt) + + await ws.commit() + outcome = "ok" + except Exception: + logger.exception( + "finalize_assistant_turn failed (chat_id=%s, message_id=%s, turn_id=%s)", + chat_id, + message_id, + turn_id, + ) + finally: + _perf_log.info( + "[finalize_assistant_turn] outcome=%s chat_id=%s message_id=%s " + "turn_id=%s parts=%d status_marker=%s " + "token_usage_attempted=%s in %.3fs", + outcome, + chat_id, + message_id, + turn_id, + len(payload), + is_status_marker, + token_usage_attempted, + time.perf_counter() - t0, + ) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 8288fb75a..3ba3912eb 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -19,12 +19,12 @@ import re import time from collections.abc import AsyncGenerator from dataclasses import dataclass, field -from typing import Any +from functools import partial +from typing import Any, Literal from uuid import UUID import anyio from langchain_core.messages import HumanMessage -from sqlalchemy import func from sqlalchemy.future import select from sqlalchemy.orm import selectinload @@ -33,6 +33,8 @@ from app.agents.multi_agent_chat import ( ) from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer +from app.agents.new_chat.context import SurfSenseContextSchema +from app.agents.new_chat.errors import BusyError from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import ( @@ -46,8 +48,11 @@ from app.agents.new_chat.memory_extraction import ( extract_and_save_memory, extract_and_save_team_memory, ) -from app.agents.new_chat.errors import BusyError -from app.agents.new_chat.middleware.busy_mutex import release_lock as _release_busy_lock +from app.agents.new_chat.middleware.busy_mutex import ( + end_turn, + get_cancel_state, + is_cancel_requested, +) from app.agents.new_chat.middleware.kb_persistence import ( commit_staged_filesystem_state, ) @@ -62,6 +67,12 @@ from app.db import ( shielded_async_session, ) from app.prompts import TITLE_GENERATION_PROMPT +from 
app.services.auto_model_pin_service import ( + is_recently_healthy, + mark_healthy, + mark_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) from app.services.chat_session_state_service import ( clear_ai_responding, set_ai_responding, @@ -76,6 +87,60 @@ _background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() logger = logging.getLogger(__name__) +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 + + +def _compute_turn_cancelling_retry_delay(attempt: int) -> int: + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) + + +def _first_interrupt_value(state: Any) -> dict[str, Any] | None: + """Return the first LangGraph interrupt payload across all snapshot tasks.""" + + def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None: + if isinstance(candidate, dict): + value = candidate.get("value", candidate) + return value if isinstance(value, dict) else None + value = getattr(candidate, "value", None) + if isinstance(value, dict): + return value + if isinstance(candidate, (list, tuple)): + for item in candidate: + extracted = _extract_interrupt_value(item) + if extracted is not None: + return extracted + return None + + for task in getattr(state, "tasks", ()) or (): + try: + interrupts = getattr(task, "interrupts", ()) or () + except (AttributeError, IndexError, TypeError): + interrupts = () + if not interrupts: + extracted = _extract_interrupt_value(task) + if extracted is not None: + return extracted + continue + for interrupt_item in interrupts: + extracted = _extract_interrupt_value(interrupt_item) + if extracted is not None: + return extracted + try: + state_interrupts = getattr(state, "interrupts", ()) or () + except (AttributeError, IndexError, TypeError): + state_interrupts = () + extracted = _extract_interrupt_value(state_interrupts) + if extracted is not None: + return extracted + return None + def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts. @@ -253,6 +318,19 @@ class StreamResult: verification_succeeded: bool = False commit_gate_passed: bool = True commit_gate_reason: str = "" + # Pre-allocated assistant ``new_chat_messages.id`` for this turn, + # captured by ``persist_assistant_shell`` right after the user row is + # persisted. ``None`` for the legacy / anonymous code paths that don't + # opt in to server-side ``ContentPart[]`` projection. + assistant_message_id: int | None = None + # In-memory mirror of the FE's assistant-ui ``ContentPartsState``, + # populated by the lifecycle methods called from ``_stream_agent_events`` + # at each ``streaming_service.format_*`` yield site. Snapshot in the + # streaming ``finally`` to produce the rich JSONB persisted by + # ``finalize_assistant_turn``. ``repr=False`` keeps the + # log-on-error path (``StreamResult`` is logged in some error + # branches) from dumping a potentially-large parts list. 
+ content_builder: Any | None = field(default=None, repr=False) def _safe_float(value: Any, default: float = 0.0) -> float: @@ -285,20 +363,17 @@ def _tool_output_has_error(tool_output: Any) -> bool: return False -def _extract_resolved_file_path(*, tool_name: str, tool_output: Any) -> str | None: +def _extract_resolved_file_path( + *, tool_name: str, tool_output: Any, tool_input: Any | None = None +) -> str | None: if isinstance(tool_output, dict): path_value = tool_output.get("path") if isinstance(path_value, str) and path_value.strip(): return path_value.strip() - text = _tool_output_to_text(tool_output) - if tool_name == "write_file": - match = re.search(r"Updated file\s+(.+)$", text.strip()) - if match: - return match.group(1).strip() - if tool_name == "edit_file": - match = re.search(r"in '([^']+)'", text) - if match: - return match.group(1).strip() + if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip(): + return file_path.strip() return None @@ -344,6 +419,273 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: ) +def _log_chat_stream_error( + *, + flow: Literal["new", "resume", "regenerate"], + error_kind: str, + error_code: str | None, + severity: Literal["info", "warn", "error"], + is_expected: bool, + request_id: str | None, + thread_id: int | None, + search_space_id: int | None, + user_id: str | None, + message: str, + extra: dict[str, Any] | None = None, +) -> None: + payload: dict[str, Any] = { + "event": "chat_stream_error", + "flow": flow, + "error_kind": error_kind, + "error_code": error_code, + "severity": severity, + "is_expected": is_expected, + "request_id": request_id or "unknown", + "thread_id": thread_id, + "search_space_id": search_space_id, + "user_id": user_id, + "message": message, + } + if extra: + payload.update(extra) + + logger = logging.getLogger(__name__) + rendered = json.dumps(payload, ensure_ascii=False) + if severity == "error": + logger.error("[chat_stream_error] %s", rendered) + elif severity == "warn": + logger.warning("[chat_stream_error] %s", rendered) + else: + logger.info("[chat_stream_error] %s", rendered) + + +def _parse_error_payload(message: str) -> dict[str, Any] | None: + candidates = [message] + first_brace_idx = message.find("{") + if first_brace_idx >= 0: + candidates.append(message[first_brace_idx:]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + return parsed + except Exception: + continue + return None + + +def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("code")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.append(nested.get("code")) + for value in candidates: + try: + if value is None: + continue + return int(value) + except Exception: + continue + return None + + +def _is_provider_rate_limited(exc: BaseException) -> bool: + """Best-effort detection for provider-side runtime throttling. 
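+
+    For example (payload shapes illustrative)::
+
+        _is_provider_rate_limited(Exception('{"error": {"code": 429}}'))
+        # -> True (nested provider code)
+        _is_provider_rate_limited(Exception("context length exceeded"))
+        # -> False (not a throttling shape)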
+ + Covers LiteLLM/OpenRouter shapes like: + - class name contains ``RateLimit`` + - nested payload ``{"error": {"code": 429}}`` + - nested payload ``{"error": {"type": "rate_limit_error"}}`` + """ + raw = str(exc) + lowered = raw.lower() + if "ratelimit" in type(exc).__name__.lower(): + return True + parsed = _parse_error_payload(raw) + provider_code = _extract_provider_error_code(parsed) + if provider_code == 429: + return True + + provider_error_type = "" + if parsed: + top_type = parsed.get("type") + if isinstance(top_type, str): + provider_error_type = top_type.lower() + nested = parsed.get("error") + if isinstance(nested, dict): + nested_type = nested.get("type") + if isinstance(nested_type, str): + provider_error_type = nested_type.lower() + if provider_error_type == "rate_limit_error": + return True + + return ( + "rate limited" in lowered + or "rate-limited" in lowered + or "temporarily rate-limited upstream" in lowered + ) + + +_PREFLIGHT_TIMEOUT_SEC: float = 2.5 +_PREFLIGHT_MAX_TOKENS: int = 1 + + +async def _preflight_llm(llm: Any) -> None: + """Issue a minimal completion to confirm the pinned model isn't 429'ing. + + Used before agent build / planner / classifier / title-gen so a known-bad + free OpenRouter deployment is detected and repinned before it cascades + into multiple wasted internal calls. The probe is intentionally cheap: + one token, low timeout, tagged ``surfsense:internal`` so token tracking + and SSE pipelines treat it as overhead rather than user output. + + Raises the original exception when the provider responds with a + rate-limit-shaped error so the caller can drive the cooldown/repin + branch via :func:`_is_provider_rate_limited`. Other transient failures + are swallowed — the caller continues to the normal stream path and the + in-stream recovery loop remains the safety net. + """ + from litellm import acompletion + + model = getattr(llm, "model", None) + if not model or model == "auto": + # Auto-mode router doesn't have a single deployment to ping; the + # router itself handles per-deployment rate-limit accounting. + return + + try: + await acompletion( + model=model, + messages=[{"role": "user", "content": "ping"}], + api_key=getattr(llm, "api_key", None), + api_base=getattr(llm, "api_base", None), + max_tokens=_PREFLIGHT_MAX_TOKENS, + timeout=_PREFLIGHT_TIMEOUT_SEC, + stream=False, + metadata={"tags": ["surfsense:internal", "auto-pin-preflight"]}, + ) + except Exception as exc: + if _is_provider_rate_limited(exc): + raise + logging.getLogger(__name__).debug( + "auto_pin_preflight non_rate_limit_error model=%s err=%s", + model, + exc, + ) + + +async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None: + """Wait for a discarded speculative agent build to release shared state. + + Used by the parallel preflight + agent-build path. The speculative build + closes over the request-scoped ``AsyncSession`` (for the brief connector + discovery / tool-factory window before its CPU work moves into a worker + thread). If preflight reports a 429 we want to fall back to the original + repin → reload → rebuild path, but we MUST NOT touch ``session`` again + until any in-flight session work owned by the speculative build has + fully settled — :class:`sqlalchemy.ext.asyncio.AsyncSession` is not + concurrency-safe and the same hazard cost us a hard ``InvalidRequestError`` + earlier in this PR (see ``connector_service`` parallel-gather revert). 
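+
+    Usage sketch (``build_agent``, ``session`` and ``llm`` are
+    illustrative names, not the real call sites)::
+
+        build_task = asyncio.create_task(build_agent(session))
+        try:
+            await _preflight_llm(llm)
+        except Exception:
+            # preflight saw a 429 -> discard the speculative build, but
+            # settle it before reusing ``session`` on the repin path
+            await _settle_speculative_agent_build(build_task)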
+ + We simply ``await`` the task and swallow any exception: in this path the + build's outcome is irrelevant — success populates the agent cache (a free + side effect), failure is discarded. The wasted CPU is acceptable since + 429 fallbacks are rare and the original sequential code also paid the + full build cost on the same path. + """ + with contextlib.suppress(BaseException): + await task + + +def _classify_stream_exception( + exc: Exception, + *, + flow_label: str, +) -> tuple[ + str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None +]: + raw = str(exc) + if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: + busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None + if busy_thread_id and is_cancel_requested(busy_thread_id): + cancel_state = get_cancel_state(busy_thread_id) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(time.time() * 1000) + retry_after_ms + return ( + "thread_busy", + "TURN_CANCELLING", + "info", + True, + "A previous response is still stopping. Please try again in a moment.", + { + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + }, + ) + return ( + "thread_busy", + "THREAD_BUSY", + "warn", + True, + "Another response is still finishing for this thread. Please try again in a moment.", + None, + ) + + if _is_provider_rate_limited(exc): + return ( + "rate_limited", + "RATE_LIMITED", + "warn", + True, + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + None, + ) + + return ( + "server_error", + "SERVER_ERROR", + "error", + False, + f"Error during {flow_label}: {raw}", + None, + ) + + +def _emit_stream_terminal_error( + *, + streaming_service: VercelStreamingService, + flow: str, + request_id: str | None, + thread_id: int, + search_space_id: int, + user_id: str | None, + message: str, + error_kind: str = "server_error", + error_code: str = "SERVER_ERROR", + severity: Literal["info", "warn", "error"] = "error", + is_expected: bool = False, + extra: dict[str, Any] | None = None, +) -> str: + _log_chat_stream_error( + flow=flow, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + request_id=request_id, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=user_id, + message=message, + extra=extra, + ) + return streaming_service.format_error(message, error_code=error_code, extra=extra) + + def _legacy_match_lc_id( pending_tool_call_chunks: list[dict[str, Any]], tool_name: str, @@ -395,6 +737,8 @@ async def _stream_agent_events( fallback_commit_created_by_id: str | None = None, fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, fallback_commit_thread_id: int | None = None, + runtime_context: Any = None, + content_builder: Any | None = None, ) -> AsyncGenerator[str, None]: """Shared async generator that streams and formats astream_events from the agent. @@ -411,6 +755,15 @@ async def _stream_agent_events( initial_step_id: If set, the helper inherits an already-active thinking step. initial_step_title: Title of the inherited thinking step. initial_step_items: Items of the inherited thinking step. + content_builder: Optional ``AssistantContentBuilder``. 
When set, every + ``streaming_service.format_*`` yield site also drives the matching + builder lifecycle method (``on_text_*``, ``on_reasoning_*``, + ``on_tool_*``, ``on_thinking_step``, ``on_step_separator``) so the + in-memory ``ContentPart[]`` projection stays in lockstep with what + the FE renders live. Pure in-memory accumulation — no DB I/O — + consumed by the streaming ``finally`` to produce the rich JSONB + persisted via ``finalize_assistant_turn``. ``None`` (the default) + is used by the anonymous / legacy code paths and is a no-op. Yields: SSE-formatted strings for each event. @@ -451,6 +804,7 @@ async def _stream_agent_events( # fallback path only and never re-pops a chunk we already streamed. pending_tool_call_chunks: list[dict[str, Any]] = [] lc_tool_call_id_by_run: dict[str, str] = {} + file_path_by_run: dict[str, str] = {} # parity_v2 only: live tool-call argument streaming. ``index_to_meta`` # is keyed by the chunk's ``index`` field — LangChain @@ -474,12 +828,46 @@ async def _stream_agent_events( current_lc_tool_call_id: dict[str, str | None] = {"value": None} def _emit_tool_output(call_id: str, output: Any) -> str: + # Drive the builder before formatting the SSE so the in-memory + # ContentPart[] mirror sees the result attached to the same + # card the FE will render. Builder method is a no-op when + # ``content_builder`` is None (anonymous / legacy paths). + if content_builder is not None: + content_builder.on_tool_output_available( + call_id, output, current_lc_tool_call_id["value"] + ) return streaming_service.format_tool_output_available( call_id, output, langchain_tool_call_id=current_lc_tool_call_id["value"], ) + def _emit_thinking_step( + *, + step_id: str, + title: str, + status: str = "in_progress", + items: list[str] | None = None, + ) -> str: + """Format a thinking-step SSE event and notify the builder. + + Single helper used at every ``format_thinking_step`` yield site + in this generator. Drives ``AssistantContentBuilder.on_thinking_step`` + first so the FE-mirror state lands the update before the SSE + carrying the same data leaves the wire — order matches the FE + pipeline (``processSharedStreamEvent`` updates state, then + flushes). Builder call is a no-op when ``content_builder`` is + None (anonymous / legacy paths). + """ + if content_builder is not None: + content_builder.on_thinking_step(step_id, title, status, items) + return streaming_service.format_thinking_step( + step_id=step_id, + title=title, + status=status, + items=items, + ) + def next_thinking_step_id() -> str: nonlocal thinking_step_counter thinking_step_counter += 1 @@ -489,7 +877,7 @@ async def _stream_agent_events( nonlocal last_active_step_id if last_active_step_id and last_active_step_id not in completed_step_ids: completed_step_ids.add(last_active_step_id) - event = streaming_service.format_thinking_step( + event = _emit_thinking_step( step_id=last_active_step_id, title=last_active_step_title, status="completed", @@ -499,7 +887,18 @@ async def _stream_agent_events( return event return None - async for event in agent.astream_events(input_data, config=config, version="v2"): + # Per-invocation runtime context (Phase 1.5). When supplied, + # ``KnowledgePriorityMiddleware`` reads ``mentioned_document_ids`` + # from ``runtime.context`` instead of its constructor closure — the + # prerequisite that lets the compiled-agent cache (Phase 1) reuse a + # single graph across turns. 
Astream_events_kwargs stays empty when + # callers leave ``runtime_context`` as ``None`` to preserve the + # legacy code path bit-for-bit. + astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"} + if runtime_context is not None: + astream_kwargs["context"] = runtime_context + + async for event in agent.astream_events(input_data, **astream_kwargs): event_type = event.get("event", "") if event_type == "on_chat_model_stream": @@ -523,6 +922,8 @@ async def _stream_agent_events( if parity_v2 and reasoning_delta: if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) + if content_builder is not None: + content_builder.on_text_end(current_text_id) current_text_id = None if current_reasoning_id is None: completion_event = complete_current_step() @@ -535,13 +936,21 @@ async def _stream_agent_events( just_finished_tool = False current_reasoning_id = streaming_service.generate_reasoning_id() yield streaming_service.format_reasoning_start(current_reasoning_id) + if content_builder is not None: + content_builder.on_reasoning_start(current_reasoning_id) yield streaming_service.format_reasoning_delta( current_reasoning_id, reasoning_delta ) + if content_builder is not None: + content_builder.on_reasoning_delta( + current_reasoning_id, reasoning_delta + ) if text_delta: if current_reasoning_id is not None: yield streaming_service.format_reasoning_end(current_reasoning_id) + if content_builder is not None: + content_builder.on_reasoning_end(current_reasoning_id) current_reasoning_id = None if current_text_id is None: completion_event = complete_current_step() @@ -554,8 +963,12 @@ async def _stream_agent_events( just_finished_tool = False current_text_id = streaming_service.generate_text_id() yield streaming_service.format_text_start(current_text_id) + if content_builder is not None: + content_builder.on_text_start(current_text_id) yield streaming_service.format_text_delta(current_text_id, text_delta) accumulated_text += text_delta + if content_builder is not None: + content_builder.on_text_delta(current_text_id, text_delta) # Live tool-call argument streaming. Runs AFTER text/reasoning # processing so chunks containing both stay in their natural @@ -587,11 +1000,17 @@ async def _stream_agent_events( # within the same stream window. if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) + if content_builder is not None: + content_builder.on_text_end(current_text_id) current_text_id = None if current_reasoning_id is not None: yield streaming_service.format_reasoning_end( current_reasoning_id ) + if content_builder is not None: + content_builder.on_reasoning_end( + current_reasoning_id + ) current_reasoning_id = None index_to_meta[idx] = { @@ -604,6 +1023,8 @@ async def _stream_agent_events( name, langchain_tool_call_id=lc_id, ) + if content_builder is not None: + content_builder.on_tool_input_start(ui_id, name, lc_id) # Emit args delta for any chunk at a registered # index (including idless continuations). 
Once an @@ -619,6 +1040,10 @@ async def _stream_agent_events( yield streaming_service.format_tool_input_delta( meta["ui_id"], args_chunk ) + if content_builder is not None: + content_builder.on_tool_input_delta( + meta["ui_id"], args_chunk + ) else: pending_tool_call_chunks.append(tcc) @@ -629,9 +1054,15 @@ async def _stream_agent_events( tool_input = event.get("data", {}).get("input", {}) if tool_name in ("write_file", "edit_file"): result.write_attempted = True + if isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip() and run_id: + file_path_by_run[run_id] = file_path.strip() if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) + if content_builder is not None: + content_builder.on_text_end(current_text_id) current_text_id = None if last_active_step_title != "Synthesizing response": @@ -652,7 +1083,7 @@ async def _stream_agent_events( ) last_active_step_title = "Listing files" last_active_step_items = [ls_path] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Listing files", status="in_progress", @@ -667,7 +1098,7 @@ async def _stream_agent_events( display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] last_active_step_title = "Reading file" last_active_step_items = [display_fp] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Reading file", status="in_progress", @@ -682,7 +1113,7 @@ async def _stream_agent_events( display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] last_active_step_title = "Writing file" last_active_step_items = [display_fp] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Writing file", status="in_progress", @@ -697,7 +1128,7 @@ async def _stream_agent_events( display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] last_active_step_title = "Editing file" last_active_step_items = [display_fp] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Editing file", status="in_progress", @@ -714,7 +1145,7 @@ async def _stream_agent_events( ) last_active_step_title = "Searching files" last_active_step_items = [f"{pat} in {base_path}"] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Searching files", status="in_progress", @@ -734,7 +1165,7 @@ async def _stream_agent_events( last_active_step_items = [ f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "") ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Searching content", status="in_progress", @@ -749,7 +1180,7 @@ async def _stream_agent_events( display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:] last_active_step_title = "Deleting file" last_active_step_items = [display_path] if display_path else [] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Deleting file", status="in_progress", @@ -766,7 +1197,7 @@ async def _stream_agent_events( ) last_active_step_title = "Deleting folder" last_active_step_items = [display_path] if display_path else [] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Deleting folder", status="in_progress", @@ -783,7 +1214,7 @@ async def _stream_agent_events( ) last_active_step_title = "Creating folder" last_active_step_items = 
[display_path] if display_path else [] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Creating folder", status="in_progress", @@ -806,7 +1237,7 @@ async def _stream_agent_events( last_active_step_items = ( [f"{display_src} → {display_dst}"] if src or dst else [] ) - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Moving file", status="in_progress", @@ -823,7 +1254,7 @@ async def _stream_agent_events( if todo_count else [] ) - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Planning tasks", status="in_progress", @@ -838,7 +1269,7 @@ async def _stream_agent_events( display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "") last_active_step_title = "Saving document" last_active_step_items = [display_title] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Saving document", status="in_progress", @@ -854,7 +1285,7 @@ async def _stream_agent_events( last_active_step_items = [ f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}" ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Generating image", status="in_progress", @@ -870,7 +1301,7 @@ async def _stream_agent_events( last_active_step_items = [ f"URL: {url[:80]}{'...' if len(url) > 80 else ''}" ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Scraping webpage", status="in_progress", @@ -893,7 +1324,7 @@ async def _stream_agent_events( f"Content: {content_len:,} characters", "Preparing audio generation...", ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Generating podcast", status="in_progress", @@ -914,7 +1345,7 @@ async def _stream_agent_events( f"Topic: {report_topic}", "Analyzing source content...", ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title=step_title, status="in_progress", @@ -929,7 +1360,7 @@ async def _stream_agent_events( display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "") last_active_step_title = "Running command" last_active_step_items = [f"$ {display_cmd}"] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Running command", status="in_progress", @@ -946,7 +1377,7 @@ async def _stream_agent_events( tool_name.replace("_", " ").strip().capitalize() or tool_name ) last_active_step_items = [] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title=last_active_step_title, status="in_progress", @@ -1007,6 +1438,10 @@ async def _stream_agent_events( tool_name, langchain_tool_call_id=langchain_tool_call_id, ) + if content_builder is not None: + content_builder.on_tool_input_start( + tool_call_id, tool_name, langchain_tool_call_id + ) if run_id: ui_tool_call_id_by_run[run_id] = tool_call_id @@ -1029,12 +1464,20 @@ async def _stream_agent_events( _safe_input, langchain_tool_call_id=langchain_tool_call_id, ) + if content_builder is not None: + content_builder.on_tool_input_available( + tool_call_id, + tool_name, + _safe_input, + langchain_tool_call_id, + ) elif event_type == "on_tool_end": active_tool_depth = max(0, active_tool_depth - 1) run_id = event.get("run_id", "") tool_name = event.get("name", "unknown_tool") raw_output = event.get("data", 
{}).get("output", "") + staged_file_path = file_path_by_run.pop(run_id, None) if run_id else None if tool_name == "update_memory": called_update_memory = True @@ -1100,70 +1543,70 @@ async def _stream_agent_events( current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] if tool_name == "read_file": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Reading file", status="completed", items=last_active_step_items, ) elif tool_name == "write_file": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Writing file", status="completed", items=last_active_step_items, ) elif tool_name == "edit_file": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Editing file", status="completed", items=last_active_step_items, ) elif tool_name == "glob": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Searching files", status="completed", items=last_active_step_items, ) elif tool_name == "grep": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Searching content", status="completed", items=last_active_step_items, ) elif tool_name == "rm": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Deleting file", status="completed", items=last_active_step_items, ) elif tool_name == "rmdir": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Deleting folder", status="completed", items=last_active_step_items, ) elif tool_name == "mkdir": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Creating folder", status="completed", items=last_active_step_items, ) elif tool_name == "move_file": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Moving file", status="completed", items=last_active_step_items, ) elif tool_name == "write_todos": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Planning tasks", status="completed", @@ -1180,7 +1623,7 @@ async def _stream_agent_events( *last_active_step_items, result_str[:80] if is_error else "Saved to knowledge base", ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Saving document", status="completed", @@ -1199,7 +1642,7 @@ async def _stream_agent_events( else "Generation failed" ) completed_items = [*last_active_step_items, f"Error: {error_msg}"] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Generating image", status="completed", @@ -1223,7 +1666,7 @@ async def _stream_agent_events( ] else: completed_items = [*last_active_step_items, "Content extracted"] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Scraping webpage", status="completed", @@ -1240,10 +1683,10 @@ async def _stream_agent_events( if isinstance(tool_output, dict) else "Podcast" ) - if podcast_status == "processing": + if podcast_status in ("pending", "generating", "processing"): completed_items = [ f"Title: {podcast_title}", - "Audio generation started", + "Podcast generation started", "Processing in background...", ] elif podcast_status == 
"already_generating": @@ -1252,7 +1695,7 @@ async def _stream_agent_events( "Podcast already in progress", "Please wait for it to complete", ] - elif podcast_status == "error": + elif podcast_status in ("failed", "error"): error_msg = ( tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) @@ -1262,9 +1705,14 @@ async def _stream_agent_events( f"Title: {podcast_title}", f"Error: {error_msg[:50]}", ] + elif podcast_status in ("ready", "success"): + completed_items = [ + f"Title: {podcast_title}", + "Podcast ready", + ] else: completed_items = last_active_step_items - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Generating podcast", status="completed", @@ -1299,7 +1747,7 @@ async def _stream_agent_events( ] else: completed_items = last_active_step_items - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Generating video presentation", status="completed", @@ -1347,7 +1795,7 @@ async def _stream_agent_events( else: completed_items = last_active_step_items - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title=step_title, status="completed", @@ -1373,7 +1821,7 @@ async def _stream_agent_events( ] else: completed_items = [*last_active_step_items, "Finished"] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Running command", status="completed", @@ -1413,7 +1861,7 @@ async def _stream_agent_events( completed_items.append(f"(+{len(file_names) - 4} more)") else: completed_items = ["No files found"] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Listing files", status="completed", @@ -1425,7 +1873,7 @@ async def _stream_agent_events( fallback_title = ( tool_name.replace("_", " ").strip().capitalize() or tool_name ) - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title=fallback_title, status="completed", @@ -1444,20 +1892,28 @@ async def _stream_agent_events( if isinstance(tool_output, dict) else {"result": tool_output}, ) - if ( - isinstance(tool_output, dict) - and tool_output.get("status") == "success" + if isinstance(tool_output, dict) and tool_output.get("status") in ( + "pending", + "generating", + "processing", + ): + yield streaming_service.format_terminal_info( + f"Podcast queued: {tool_output.get('title', 'Podcast')}", + "success", + ) + elif isinstance(tool_output, dict) and tool_output.get("status") in ( + "ready", + "success", ): yield streaming_service.format_terminal_info( f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}", "success", ) - else: - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) + elif isinstance(tool_output, dict) and tool_output.get("status") in ( + "failed", + "error", + ): + error_msg = tool_output.get("error", "Unknown error") yield streaming_service.format_terminal_info( f"Podcast generation failed: {error_msg}", "error", @@ -1548,6 +2004,9 @@ async def _stream_agent_events( resolved_path = _extract_resolved_file_path( tool_name=tool_name, tool_output=tool_output, + tool_input={"file_path": staged_file_path} + if staged_file_path + else None, ) result_text = _tool_output_to_text(tool_output) if _tool_output_has_error(tool_output): @@ -1754,7 +2213,7 @@ async def _stream_agent_events( # Phase transitions: 
replace everything after topic last_active_step_items = [*topic_items, message] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=last_active_step_id, title=last_active_step_title, status="in_progress", @@ -1796,10 +2255,14 @@ async def _stream_agent_events( elif event_type in ("on_chain_end", "on_agent_end"): if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) + if content_builder is not None: + content_builder.on_text_end(current_text_id) current_text_id = None if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) + if content_builder is not None: + content_builder.on_text_end(current_text_id) completion_event = complete_current_step() if completion_event: @@ -1884,8 +2347,14 @@ async def _stream_agent_events( ) gate_text_id = streaming_service.generate_text_id() yield streaming_service.format_text_start(gate_text_id) + if content_builder is not None: + content_builder.on_text_start(gate_text_id) yield streaming_service.format_text_delta(gate_text_id, gate_notice) + if content_builder is not None: + content_builder.on_text_delta(gate_text_id, gate_notice) yield streaming_service.format_text_end(gate_text_id) + if content_builder is not None: + content_builder.on_text_end(gate_text_id) yield streaming_service.format_terminal_info(gate_notice, "error") accumulated_text = gate_notice else: @@ -1896,19 +2365,11 @@ async def _stream_agent_events( result.agent_called_update_memory = called_update_memory _log_file_contract("turn_outcome", result) - snapshot_interrupts = getattr(state, "interrupts", ()) or () - interrupt_value = None - if snapshot_interrupts: - interrupt_value = snapshot_interrupts[0].value - else: - for task in state.tasks or []: - if task.interrupts: - interrupt_value = task.interrupts[0].value - break + interrupt_value = _first_interrupt_value(state) if interrupt_value is not None: result.is_interrupted = True result.interrupt_value = interrupt_value - yield streaming_service.format_interrupt_request(interrupt_value) + yield streaming_service.format_interrupt_request(result.interrupt_value) async def stream_new_chat( @@ -1919,6 +2380,7 @@ async def stream_new_chat( llm_config_id: int = -1, mentioned_document_ids: list[int] | None = None, mentioned_surfsense_doc_ids: list[int] | None = None, + mentioned_documents: list[dict[str, Any]] | None = None, checkpoint_id: str | None = None, needs_history_bootstrap: bool = False, thread_visibility: ChatVisibility | None = None, @@ -1927,6 +2389,7 @@ async def stream_new_chat( filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, user_image_data_urls: list[str] | None = None, + flow: Literal["new", "regenerate"] = "new", ) -> AsyncGenerator[str, None]: """ Stream chat responses from the new SurfSense deep agent. @@ -1974,14 +2437,26 @@ async def stream_new_chat( accumulator = start_turn() - # Premium quota tracking state - _premium_reserved = 0 + # Premium credit (USD micro-units) tracking state. Stores the + # amount reserved up front so we can release it on cancellation + # and finalize-debit the actual provider cost reported by LiteLLM. + _premium_reserved_micros = 0 _premium_request_id: str | None = None # ``BusyError`` fires before the lock is acquired; the ``finally`` must # not release the in-flight caller's lock. 
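# ---------------------------------------------------------------------------
# Aside: a minimal sketch, not part of this patch. The flag set below exists
# because the release in ``finally`` must be conditional on *this* caller
# having acquired the per-thread lock; otherwise a rejected second caller
# would free the in-flight caller's lock. ``acquire_turn`` and ``BusyError``
# here are hypothetical stand-ins for the patch's BusyMutex helpers.
def run_turn_sketch(thread_id: str) -> None:
    acquired = False
    try:
        acquire_turn(thread_id)  # hypothetical; raises BusyError while a turn is live
        acquired = True
        ...  # stream the turn
    finally:
        if acquired:  # never release a lock this caller does not hold
            end_turn(thread_id)
# ---------------------------------------------------------------------------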
_busy_error_raised = False + _emit_stream_error = partial( + _emit_stream_terminal_error, + streaming_service=streaming_service, + flow=flow, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + ) + session = async_session_maker() try: # Mark AI as responding to this user for live collaboration @@ -1989,87 +2464,280 @@ async def stream_new_chat( await set_ai_responding(session, chat_id, UUID(user_id)) # Load LLM config - supports both YAML (negative IDs) and database (positive IDs) agent_config: AgentConfig | None = None + requested_llm_config_id = llm_config_id + + async def _load_llm_bundle( + config_id: int, + ) -> tuple[Any, AgentConfig | None, str | None]: + if config_id >= 0: + loaded_agent_config = await load_agent_config( + session=session, + config_id=config_id, + search_space_id=search_space_id, + ) + if not loaded_agent_config: + return ( + None, + None, + f"Failed to load NewLLMConfig with id {config_id}", + ) + return ( + create_chat_litellm_from_agent_config(loaded_agent_config), + loaded_agent_config, + None, + ) + + loaded_llm_config = load_global_llm_config_by_id(config_id) + if not loaded_llm_config: + return None, None, f"Failed to load LLM config with id {config_id}" + return ( + create_chat_litellm_from_config(loaded_llm_config), + AgentConfig.from_yaml_config(loaded_llm_config), + None, + ) _t0 = time.perf_counter() - if llm_config_id >= 0: - # Positive ID: Load from NewLLMConfig database table - agent_config = await load_agent_config( - session=session, - config_id=llm_config_id, - search_space_id=search_space_id, + # Image-bearing turns force the Auto-pin resolver to filter the + # candidate pool to vision-capable cfgs (and force-repin a + # text-only existing pin). For explicit selections this flag is + # a no-op — the resolver returns the user's chosen id unchanged. + _requires_image_input = bool(user_image_data_urls) + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=llm_config_id, + requires_image_input=_requires_image_input, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + # Auto-pin's "no vision-capable cfg" path raises a ValueError + # whose message we map to the friendly image-input SSE error + # so the user sees the same message regardless of whether + # the gate fired in Auto-mode or in the agent_config check + # below. 
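# ---------------------------------------------------------------------------
# Aside: illustrative sketch, not part of this patch. The mapping performed
# just below, isolated. The only assumption (taken from the comment above) is
# that the resolver signals the vision gap via the substring "vision-capable"
# in the ValueError text.
def classify_pin_error_sketch(
    exc: ValueError, requires_image_input: bool
) -> tuple[str, str]:
    """Return (error_code, error_kind) for a failed auto-pin resolution."""
    if requires_image_input and "vision-capable" in str(exc):
        return "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT", "user_error"
    return "SERVER_ERROR", "server_error"
# ---------------------------------------------------------------------------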
+ error_code = ( + "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" + if _requires_image_input and "vision-capable" in str(pin_error) + else "SERVER_ERROR" ) - if not agent_config: - yield streaming_service.format_error( - f"Failed to load NewLLMConfig with id {llm_config_id}" - ) - yield streaming_service.format_done() - return + error_kind = ( + "user_error" + if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" + else "server_error" + ) + yield _emit_stream_error( + message=str(pin_error), + error_kind=error_kind, + error_code=error_code, + ) + yield streaming_service.format_done() + return - # Create ChatLiteLLM from AgentConfig - llm = create_chat_litellm_from_agent_config(agent_config) - else: - # Negative ID: Load from in-memory global configs (includes dynamic OpenRouter models) - llm_config = load_global_llm_config_by_id(llm_config_id) - if not llm_config: - yield streaming_service.format_error( - f"Failed to load LLM config with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - - # Create ChatLiteLLM from global config dict - llm = create_chat_litellm_from_config(llm_config) - agent_config = AgentConfig.from_yaml_config(llm_config) + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return _perf_log.info( "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)", time.perf_counter() - _t0, llm_config_id, ) - # Premium quota reservation — applies to explicitly premium configs - # AND Auto mode (which may route to premium models). + # Capability safety net: a turn carrying user-uploaded images + # cannot be routed to a chat config that LiteLLM's authoritative + # model map *explicitly* marks as text-only (``supports_vision`` + # set to False). The check is intentionally narrow — it only + # fires when LiteLLM is *certain* the model can't accept image + # input. Unknown / unmapped / vision-capable models pass + # through. Without this guard a known-text-only model would 404 + # at the provider with ``"No endpoints found that support image + # input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk; + # failing here lets us return a friendly message that tells the + # user what to change. + if user_image_data_urls and agent_config is not None: + from app.services.provider_capabilities import ( + is_known_text_only_chat_model, + ) + + agent_litellm_params = agent_config.litellm_params or {} + agent_base_model = ( + agent_litellm_params.get("base_model") + if isinstance(agent_litellm_params, dict) + else None + ) + if is_known_text_only_chat_model( + provider=agent_config.provider, + model_name=agent_config.model_name, + base_model=agent_base_model, + custom_provider=agent_config.custom_provider, + ): + model_label = ( + agent_config.config_name or agent_config.model_name or "model" + ) + yield _emit_stream_error( + message=( + f"The selected model ({model_label}) does not support " + "image input. Switch to a vision-capable model " + "(e.g. GPT-4o, Claude, Gemini) or remove the image " + "attachment and try again." + ), + error_kind="user_error", + error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT", + ) + yield streaming_service.format_done() + return + + # Premium quota reservation for pinned premium model only. 
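# ---------------------------------------------------------------------------
# Aside: a minimal sketch, not part of this patch. The credit lifecycle the
# block below implements: amounts are micro-USD integers (1_000_000 micros
# == $1.00); ``reserve`` / ``finalize`` / ``release`` are hypothetical
# stand-ins for the TokenQuotaService calls used in the real code.
async def premium_turn_sketch(user_id: str, estimate_micros: int) -> None:
    request_id = "req-123"
    await reserve(user_id, request_id, estimate_micros)  # hold the estimate up front
    try:
        actual_micros = await run_llm_calls()  # provider cost reported by LiteLLM
        # Debit the real cost; the delta against the hold is refunded (or
        # topped up) inside finalize.
        await finalize(
            user_id, request_id, actual_micros, reserved_micros=estimate_micros
        )
    except BaseException:
        await release(user_id, reserved_micros=estimate_micros)  # cancel: return hold
        raise
# ---------------------------------------------------------------------------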
         _needs_premium_quota = (
-            agent_config is not None
-            and user_id
-            and (agent_config.is_premium or agent_config.is_auto_mode)
+            agent_config is not None and user_id and agent_config.is_premium
         )
         if _needs_premium_quota:
             import uuid as _uuid

-            from app.config import config as _app_config
-            from app.services.token_quota_service import TokenQuotaService
+            from app.services.token_quota_service import (
+                TokenQuotaService,
+                estimate_call_reserve_micros,
+            )

             _premium_request_id = _uuid.uuid4().hex[:16]
-            reserve_amount = min(
-                agent_config.quota_reserve_tokens
-                or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
-                _app_config.QUOTA_MAX_RESERVE_PER_CALL,
+            _agent_litellm_params = agent_config.litellm_params or {}
+            _agent_base_model = (
+                _agent_litellm_params.get("base_model") or agent_config.model_name or ""
+            )
+            reserve_amount_micros = estimate_call_reserve_micros(
+                base_model=_agent_base_model,
+                quota_reserve_tokens=agent_config.quota_reserve_tokens,
             )
             async with shielded_async_session() as quota_session:
                 quota_result = await TokenQuotaService.premium_reserve(
                     db_session=quota_session,
                     user_id=UUID(user_id),
                     request_id=_premium_request_id,
-                    reserve_tokens=reserve_amount,
+                    reserve_micros=reserve_amount_micros,
                 )
-            _premium_reserved = reserve_amount
+            _premium_reserved_micros = reserve_amount_micros

             if not quota_result.allowed:
-                if agent_config.is_premium:
-                    yield streaming_service.format_error(
-                        "Premium token quota exceeded. Please purchase more tokens to continue using premium models."
+                if requested_llm_config_id == 0:
+                    try:
+                        llm_config_id = (
+                            await resolve_or_get_pinned_llm_config_id(
+                                session,
+                                thread_id=chat_id,
+                                search_space_id=search_space_id,
+                                user_id=user_id,
+                                selected_llm_config_id=0,
+                                force_repin_free=True,
+                                requires_image_input=_requires_image_input,
+                            )
+                        ).resolved_llm_config_id
+                    except ValueError as pin_error:
+                        yield _emit_stream_error(
+                            message=str(pin_error),
+                            error_kind="server_error",
+                            error_code="SERVER_ERROR",
+                        )
+                        yield streaming_service.format_done()
+                        return
+
+                    llm, agent_config, llm_load_error = await _load_llm_bundle(
+                        llm_config_id
+                    )
+                    if llm_load_error:
+                        yield _emit_stream_error(
+                            message=llm_load_error,
+                            error_kind="server_error",
+                            error_code="SERVER_ERROR",
+                        )
+                        yield streaming_service.format_done()
+                        return
+                    _premium_request_id = None
+                    _premium_reserved_micros = 0
+                    _log_chat_stream_error(
+                        flow=flow,
+                        error_kind="premium_quota_exhausted",
+                        error_code="PREMIUM_QUOTA_EXHAUSTED",
+                        severity="info",
+                        is_expected=True,
+                        request_id=request_id,
+                        thread_id=chat_id,
+                        search_space_id=search_space_id,
+                        user_id=user_id,
+                        message=(
+                            "Premium quota exhausted on pinned model; auto-fallback switched to a free model"
+                        ),
+                        extra={
+                            "fallback_config_id": llm_config_id,
+                            "auto_fallback": True,
+                        },
+                    )
+                else:
+                    yield _emit_stream_error(
+                        message=(
+                            "Buy more credits to continue with this model, or switch to a free model"
+                        ),
+                        error_kind="premium_quota_exhausted",
+                        error_code="PREMIUM_QUOTA_EXHAUSTED",
+                        severity="info",
+                        is_expected=True,
+                        extra={
+                            "resolved_config_id": llm_config_id,
+                            "auto_fallback": False,
+                        },
                     )
                     yield streaming_service.format_done()
                     return
-                # Auto mode: quota exhausted but we can still proceed
-                # (the router may pick a free model). Reset reservation.
- _premium_request_id = None - _premium_reserved = 0 if not llm: - yield streaming_service.format_error("Failed to create LLM instance") + yield _emit_stream_error( + message="Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return + # Auto-mode preflight ping. Runs ONLY for thread-pinned auto cfgs + # (negative ids selected via ``resolve_or_get_pinned_llm_config_id``) + # whose health hasn't already been confirmed within the TTL window. + # Detecting a 429 here lets us repin BEFORE the planner/classifier/ + # title-generation LLM calls fan out and each independently hit the + # same upstream rate limit. + # + # PERF: preflight is a network round-trip to the LLM provider (~1-5s) + # and is independent of the agent build (CPU-bound, ~5-7s). They used + # to run sequentially → ``preflight + build`` on cold cache = 11.5s. + # We now kick off preflight as a background task FIRST, then run the + # synchronous setup work and the agent build in parallel. In the + # success path (the common case) total wall time drops to roughly + # ``max(preflight, build)`` — the preflight finishes during the + # agent compile and we just consume its result. In the rare 429 + # path the speculative build is awaited to completion (so its + # session usage is fully released) via + # :func:`_settle_speculative_agent_build`, then discarded, and + # we fall back to the original repin-and-rebuild flow. + preflight_needed = ( + requested_llm_config_id == 0 + and llm_config_id < 0 + and not is_recently_healthy(llm_config_id) + ) + preflight_task: asyncio.Task[None] | None = None + _t_preflight = 0.0 + if preflight_needed: + _t_preflight = time.perf_counter() + preflight_task = asyncio.create_task( + _preflight_llm(llm), + name=f"auto_pin_preflight:{llm_config_id}", + ) + # Create connector service _t0 = time.perf_counter() connector_service = ConnectorService(session, search_space_id=search_space_id) @@ -2098,24 +2766,17 @@ async def stream_new_chat( use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED) _t0 = time.perf_counter() - if use_multi_agent: - agent = await create_registry_deep_agent( - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - mentioned_document_ids=mentioned_document_ids, - disabled_tools=disabled_tools, - ) - else: - agent = await create_surfsense_deep_agent( + agent_factory = ( + create_registry_deep_agent + if use_multi_agent + else create_surfsense_deep_agent + ) + # Speculative agent build — runs in parallel with the preflight + # task (if any). Built with the *current* ``llm`` / ``agent_config``; + # if preflight reports 429 we will discard this future and rebuild + # against the freshly pinned config below. 
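# ---------------------------------------------------------------------------
# Aside: a minimal sketch, not part of this patch. The overlap being set up
# here, reduced to its shape: start the provider ping and the CPU-bound agent
# compile together so the healthy path costs max(preflight, build) instead of
# preflight + build. ``ping_provider``, ``compile_agent``, ``repin`` and
# ``RateLimited`` are hypothetical stand-ins.
import asyncio

async def build_with_preflight_sketch(cfg):
    preflight = asyncio.create_task(ping_provider(cfg))  # network-bound, ~1-5s
    build = asyncio.create_task(compile_agent(cfg))      # CPU-bound, ~5-7s
    try:
        await preflight  # common case: completes while the build is running
    except RateLimited:
        await asyncio.gather(build, return_exceptions=True)  # settle, then discard
        return await compile_agent(repin(exclude=cfg))       # rebuild on a new pin
    return await build  # preflight ok: the speculative build is the agent
# ---------------------------------------------------------------------------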
+ agent_build_task = asyncio.create_task( + agent_factory( llm=llm, search_space_id=search_space_id, db_session=session, @@ -2129,7 +2790,116 @@ async def stream_new_chat( disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids, filesystem_selection=filesystem_selection, - ) + ), + name="agent_build:stream_new_chat", + ) + + agent: Any = None + if preflight_task is not None: + try: + await preflight_task + mark_healthy(llm_config_id) + _perf_log.info( + "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)", + llm_config_id, + time.perf_counter() - _t_preflight, + ) + except Exception as preflight_exc: + # Both branches below need the session: the non-429 path + # may unwind via cleanup that uses ``session``, and the + # 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id`` + # against it. Wait for the speculative build to release its + # session usage before we proceed. + await _settle_speculative_agent_build(agent_build_task) + if not _is_provider_rate_limited(preflight_exc): + raise + # 429: speculative agent is discarded; run the original + # repin → reload → rebuild path against the freshly + # pinned config. + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, reason="preflight_rate_limited" + ) + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, + requires_image_input=_requires_image_input, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) + if llm_load_error or not llm: + yield _emit_stream_error( + message=llm_load_error or "Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + # Trust the freshly-resolved cfg for the remainder of this + # turn rather than recursing into another preflight; the + # in-stream 429 recovery loop is still in place as the + # safety net if even this fallback hits an upstream cap. + mark_healthy(llm_config_id) + _log_chat_stream_error( + flow=flow, + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model failed preflight; switched to another " + "eligible model and continuing." + ), + extra={ + "auto_runtime_recover": True, + "preflight": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + # Rebuild against the new llm/agent_config. Sequential + # here because we no longer have anything to overlap with. 
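# ---------------------------------------------------------------------------
# Aside: one plausible shape for ``_settle_speculative_agent_build`` (the
# actual helper lives elsewhere in this patch); a sketch, not a definitive
# implementation. Awaiting to completion, rather than cancelling, matters
# because the build may be using the shared DB session.
import asyncio

async def settle_sketch(task: asyncio.Task) -> None:
    # Let the task finish and swallow its outcome; the result is about to be
    # discarded either way, but its session usage is now fully released.
    await asyncio.gather(task, return_exceptions=True)  # never raises
# ---------------------------------------------------------------------------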
+ agent = await agent_factory( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + filesystem_selection=filesystem_selection, + ) + + if agent is None: + # Either no preflight was needed, or preflight succeeded — + # in both cases the speculative build is the agent we want. + agent = await agent_build_task _perf_log.info( "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 ) @@ -2301,6 +3071,97 @@ async def stream_new_chat( "turn-info", {"chat_turn_id": stream_result.turn_id}, ) + yield streaming_service.format_data("turn-status", {"status": "busy"}) + + # Persist the user-side row for this turn before any expensive + # work runs. Closes the "ghost-thread" abuse vector + # (authenticated client hits POST /new_chat then never calls + # /messages — empty new_chat_messages, free LLM completion). + # Idempotent against the unique index in migration 141 so the + # legacy frontend appendMessage call is a no-op on the second + # writer. Hard failure aborts the turn so we never produce a + # title or assistant row that isn't anchored to a persisted + # user message. + from app.tasks.chat.content_builder import AssistantContentBuilder + from app.tasks.chat.persistence import ( + persist_assistant_shell, + persist_user_turn, + ) + + user_message_id = await persist_user_turn( + chat_id=chat_id, + user_id=user_id, + turn_id=stream_result.turn_id, + user_query=user_query, + user_image_data_urls=user_image_data_urls, + mentioned_documents=mentioned_documents, + ) + if user_message_id is None: + yield _emit_stream_error( + message=( + "We couldn't save your message. Please try again in a moment." + ), + error_kind="server_error", + error_code="MESSAGE_PERSIST_FAILED", + ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) + yield streaming_service.format_finish_step() + yield streaming_service.format_finish() + yield streaming_service.format_done() + return + + # Emit canonical user message id BEFORE any LLM streaming so the + # FE can rename its optimistic ``msg-user-XXX`` placeholder to + # ``msg-{user_message_id}`` and unlock features gated on a real + # DB id (comments, edit-from-this-message). See B4 in + # ``sse-based_message_id_handshake`` plan. + yield streaming_service.format_data( + "user-message-id", + {"message_id": user_message_id, "turn_id": stream_result.turn_id}, + ) + + # Pre-write the assistant row for this turn so we have a stable + # ``message_id`` to anchor mid-stream metadata (token_usage, + # future agent_action_log.message_id correlation) and a + # write-once UPDATE target at finalize time. Idempotent against + # the (thread_id, turn_id, ASSISTANT) partial unique index from + # migration 141 — if the legacy frontend appendMessage races + # this, we recover the existing row's id. + assistant_message_id = await persist_assistant_shell( + chat_id=chat_id, + user_id=user_id, + turn_id=stream_result.turn_id, + ) + if assistant_message_id is None: + # Genuine DB failure — abort the turn rather than stream + # into a void. The user row is already persisted so the + # legacy "ghost-thread" gate isn't reopened. + yield _emit_stream_error( + message=( + "We couldn't initialize the assistant message. Please try again." 
+ ), + error_kind="server_error", + error_code="MESSAGE_PERSIST_FAILED", + ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) + yield streaming_service.format_finish_step() + yield streaming_service.format_finish() + yield streaming_service.format_done() + return + + # Emit canonical assistant message id BEFORE any LLM streaming + # so the FE can rename its optimistic ``msg-assistant-XXX`` + # placeholder to ``msg-{assistant_message_id}`` and bind + # ``tokenUsageStore`` / ``pendingInterrupt`` to the real id + # immediately. See B4 in ``sse-based_message_id_handshake`` + # plan. + yield streaming_service.format_data( + "assistant-message-id", + {"message_id": assistant_message_id, "turn_id": stream_result.turn_id}, + ) + + stream_result.assistant_message_id = assistant_message_id + stream_result.content_builder = AssistantContentBuilder() # Initial thinking step - analyzing the request if mentioned_surfsense_docs: @@ -2334,6 +3195,15 @@ async def stream_new_chat( initial_items = [f"{action_verb}: {' '.join(processing_parts)}"] initial_step_id = "thinking-1" + # Drive the builder for this initial thinking step too — the + # ``_emit_thinking_step`` helper lives inside ``_stream_agent_events`` + # so it isn't in scope here, but the FE folds this step into + # the same singleton ``data-thinking-steps`` part as everything + # the agent stream emits later. Mirror that fold server-side. + if stream_result.content_builder is not None: + stream_result.content_builder.on_thinking_step( + initial_step_id, initial_title, "in_progress", initial_items + ) yield streaming_service.format_thinking_step( step_id=initial_step_id, title=initial_title, @@ -2350,16 +3220,34 @@ async def stream_new_chat( # Check if this is the first assistant response so we can generate # a title in parallel with the agent stream (better UX than waiting # until after the full response). - assistant_count_result = await session.execute( - select(func.count(NewChatMessage.id)).filter( + # Use a LIMIT 1 EXISTS-style probe rather than COUNT(*) because + # this is now a hot path executed on every turn, and COUNT scales + # with thread length (server-side persistence can grow rows + # quickly under power users). + # + # IMPORTANT: ``persist_assistant_shell`` above (line ~3112) already + # inserted THIS turn's assistant row. We must therefore exclude + # it from the probe — otherwise the gate fires on every turn + # except the very first, and title generation never runs for new + # threads. Excluding by primary key (``id != assistant_message_id``) + # is bulletproof regardless of ``turn_id`` shape (legacy NULLs, + # resume turns, etc.). + first_assistant_probe = await session.execute( + select(NewChatMessage.id) + .filter( NewChatMessage.thread_id == chat_id, NewChatMessage.role == "assistant", + NewChatMessage.id != assistant_message_id, ) + .limit(1) ) - is_first_response = (assistant_count_result.scalar() or 0) == 0 + is_first_response = first_assistant_probe.scalars().first() is None title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None - if is_first_response: + # Gate title generation on a persisted user message so a stream + # that fails before persistence (we abort above) can never leave + # behind a thread with a generated title and no anchoring rows. + if is_first_response and user_message_id is not None: async def _generate_title() -> tuple[str | None, dict | None]: """Generate a short title via litellm.acompletion. 
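# ---------------------------------------------------------------------------
# Aside: the first-response probe above, restated in isolation; a sketch
# mirroring the patch's own query and model names. LIMIT 1 touches at most
# one index entry, while COUNT(*) scans every assistant row in the thread.
from sqlalchemy import select

async def is_first_response_sketch(session, thread_id, shell_id) -> bool:
    row = await session.execute(
        select(NewChatMessage.id)
        .filter(
            NewChatMessage.thread_id == thread_id,
            NewChatMessage.role == "assistant",
            NewChatMessage.id != shell_id,  # exclude THIS turn's pre-written row
        )
        .limit(1)  # stop at the first hit instead of counting the whole thread
    )
    return row.scalars().first() is None
# ---------------------------------------------------------------------------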
@@ -2375,6 +3263,7 @@ async def stream_new_chat( from litellm import acompletion from app.services.llm_router_service import LLMRouterService + from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import _turn_accumulator _turn_accumulator.set(None) @@ -2395,11 +3284,32 @@ async def stream_new_chat( model="auto", messages=messages ) else: + # Apply the same ``api_base`` cascade chat / vision / + # image-gen call sites use so we never inherit + # ``litellm.api_base`` (commonly set by + # ``AZURE_OPENAI_ENDPOINT``) when the chat config + # itself ships an empty ``api_base``. Without this + # the title-gen on an OpenRouter chat config would + # 404 against the inherited Azure endpoint — see + # ``provider_api_base`` docstring for the same + # bug repro on the image-gen / vision paths. + raw_model = getattr(llm, "model", "") or "" + provider_prefix = ( + raw_model.split("/", 1)[0] if "/" in raw_model else None + ) + provider_value = ( + agent_config.provider if agent_config is not None else None + ) + title_api_base = resolve_api_base( + provider=provider_value, + provider_prefix=provider_prefix, + config_api_base=getattr(llm, "api_base", None), + ) response = await acompletion( - model=llm.model, + model=raw_model, messages=messages, api_key=getattr(llm, "api_key", None), - api_base=getattr(llm, "api_base", None), + api_base=title_api_base, ) usage_info = None @@ -2433,56 +3343,172 @@ async def stream_new_chat( title_emitted = False + # Build the per-invocation runtime context (Phase 1.5). + # ``mentioned_document_ids`` is read by ``KnowledgePriorityMiddleware`` + # via ``runtime.context.mentioned_document_ids`` instead of its + # ``__init__`` closure — that way the same compiled-agent instance + # can serve multiple turns with different mention lists. 
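# ---------------------------------------------------------------------------
# Aside: the per-invocation context pattern in miniature; a sketch, not the
# patch's actual schema. The compiled graph is built once, and per-turn
# values travel in a context object instead of being frozen into middleware
# ``__init__`` closures. Field names mirror the patch; ``middleware_read``
# is hypothetical.
from dataclasses import dataclass, field

@dataclass
class TurnContextSketch:
    search_space_id: int
    mentioned_document_ids: list[int] = field(default_factory=list)
    request_id: str | None = None
    turn_id: str | None = None

def middleware_read(runtime) -> list[int]:
    # Per-turn state comes from runtime.context, never from construction
    # time, so one compiled agent serves many turns with different mentions.
    return runtime.context.mentioned_document_ids
# ---------------------------------------------------------------------------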
+ runtime_context = SurfSenseContextSchema( + search_space_id=search_space_id, + mentioned_document_ids=list(mentioned_document_ids or []), + request_id=request_id, + turn_id=stream_result.turn_id, + ) + _t_stream_start = time.perf_counter() _first_event_logged = False - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=input_state, - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking", - initial_step_id=initial_step_id, - initial_step_title=initial_title, - initial_step_items=initial_items, - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_new_chat] First agent event in %.3fs (time since stream start), " - "%.3fs (total since request start) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, - ) - _first_event_logged = True - yield sse - - # Inject title update mid-stream as soon as the background task finishes - if title_task is not None and title_task.done() and not title_emitted: - generated_title, title_usage = title_task.result() - if title_usage: - accumulator.add(**title_usage) - if generated_title: - async with shielded_async_session() as title_session: - title_thread_result = await title_session.execute( - select(NewChatThread).filter(NewChatThread.id == chat_id) + runtime_rate_limit_recovered = False + while True: + try: + async for sse in _stream_agent_events( + agent=agent, + config=config, + input_data=input_state, + streaming_service=streaming_service, + result=stream_result, + step_prefix="thinking", + initial_step_id=initial_step_id, + initial_step_title=initial_title, + initial_step_items=initial_items, + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + runtime_context=runtime_context, + content_builder=stream_result.content_builder, + ): + if not _first_event_logged: + _perf_log.info( + "[stream_new_chat] First agent event in %.3fs (time since stream start), " + "%.3fs (total since request start) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, ) - title_thread = title_thread_result.scalars().first() - if title_thread: - title_thread.title = generated_title - await title_session.commit() - yield streaming_service.format_thread_title_update( - chat_id, generated_title + _first_event_logged = True + yield sse + + # Inject title update mid-stream as soon as the background + # task finishes. 
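# ---------------------------------------------------------------------------
# Aside: the injection pattern used below, in minimal form; a sketch under
# the assumption that the title task must never block the token stream. The
# task is polled with ``done()`` between events, not awaited.
import asyncio
from collections.abc import AsyncIterator

async def relay_sketch(
    events: AsyncIterator[str], title_task: asyncio.Task
) -> AsyncIterator[str]:
    emitted = False
    async for event in events:
        yield event
        if not emitted and title_task.done():  # non-blocking check per event
            title = title_task.result()
            if title:
                yield f"event: title\ndata: {title}\n\n"
            emitted = True
# ---------------------------------------------------------------------------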
+ if ( + title_task is not None + and title_task.done() + and not title_emitted + ): + generated_title, title_usage = title_task.result() + if title_usage: + accumulator.add(**title_usage) + if generated_title: + async with shielded_async_session() as title_session: + title_thread_result = await title_session.execute( + select(NewChatThread).filter( + NewChatThread.id == chat_id + ) + ) + title_thread = title_thread_result.scalars().first() + if title_thread: + title_thread.title = generated_title + await title_session.commit() + yield streaming_service.format_thread_title_update( + chat_id, generated_title + ) + title_emitted = True + break + except Exception as stream_exc: + can_runtime_recover = ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + and llm_config_id < 0 + and not _first_event_logged + and _is_provider_rate_limited(stream_exc) + ) + if not can_runtime_recover: + raise + + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + # The failed attempt may still hold the per-thread busy mutex + # (middleware teardown can lag behind raised provider errors). + # Force release before we retry within the same request. + end_turn(str(chat_id)) + mark_runtime_cooldown( + previous_config_id, + reason="provider_rate_limited", + ) + + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, + requires_image_input=_requires_image_input, ) - title_emitted = True + ).resolved_llm_config_id + + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) + if llm_load_error: + raise stream_exc + + # Title generation uses the initial llm object. After a runtime + # repin we keep the stream focused on response recovery and skip + # title generation for this turn. + if title_task is not None and not title_task.done(): + title_task.cancel() + title_task = None + + _t0 = time.perf_counter() + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + filesystem_selection=filesystem_selection, + ) + _perf_log.info( + "[stream_new_chat] Runtime rate-limit recovery repinned " + "config_id=%s -> %s and rebuilt agent in %.3fs", + previous_config_id, + llm_config_id, + time.perf_counter() - _t0, + ) + _log_chat_stream_error( + flow=flow, + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model hit runtime rate limit; switched to " + "another eligible model and retried." 
+ ), + extra={ + "auto_runtime_recover": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + continue _perf_log.info( "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)", @@ -2497,9 +3523,10 @@ async def stream_new_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s", + "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -2510,6 +3537,7 @@ async def stream_new_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -2537,29 +3565,25 @@ async def stream_new_chat( chat_id, generated_title ) - # Finalize premium quota with actual tokens. - # For Auto mode, only count tokens from calls that used premium models. + # Finalize premium credit debit with the actual provider cost + # reported by LiteLLM, summed across every call in the turn. + # Mirrors the pre-cost behaviour of "premium turn → all calls + # count" so free sub-agent calls during a premium turn still + # contribute to the bill (they're $0 in practice anyway). if _premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService - if agent_config and agent_config.is_auto_mode: - from app.services.llm_router_service import LLMRouterService - - actual_premium_tokens = LLMRouterService.compute_premium_tokens( - accumulator.calls - ) - else: - actual_premium_tokens = accumulator.grand_total - async with shielded_async_session() as quota_session: await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, - actual_tokens=actual_premium_tokens, - reserved_tokens=_premium_reserved, + actual_micros=accumulator.total_cost_micros, + reserved_micros=_premium_reserved_micros, ) + _premium_request_id = None + _premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s", @@ -2569,9 +3593,10 @@ async def stream_new_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] normal new_chat: calls=%d total=%d summary=%s", + "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -2582,6 +3607,7 @@ async def stream_new_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -2617,13 +3643,7 @@ async def stream_new_chat( task.add_done_callback(_background_tasks.discard) # Finish the step and message - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - - except BusyError as e: - _busy_error_raised = True - yield streaming_service.format_error(str(e)) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ 
-2632,12 +3652,41 @@ async def stream_new_chat( # Handle any errors import traceback + # ``BusyError`` fires before the agent acquires the lock; the + # cleanup path must skip lock release to avoid freeing the + # in-flight caller's lock. Classification is handled below. + if isinstance(e, BusyError): + _busy_error_raised = True + + ( + error_kind, + error_code, + severity, + is_expected, + user_message, + error_extra, + ) = _classify_stream_exception(e, flow_label="chat") error_message = f"Error during chat: {e!s}" print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") + if error_code == "TURN_CANCELLING": + status_payload: dict[str, Any] = {"status": "cancelling"} + if error_extra: + status_payload.update(error_extra) + yield streaming_service.format_data("turn-status", status_payload) + else: + yield streaming_service.format_data("turn-status", {"status": "busy"}) - yield streaming_service.format_error(error_message) + yield _emit_stream_error( + message=user_message, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + extra=error_extra, + ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2653,8 +3702,12 @@ async def stream_new_chat( # (CancelledError is a BaseException), and the rest of the # finally block — including session.close() — would never run. with anyio.CancelScope(shield=True): + # Authoritative fallback cleanup for lock/cancel state. Middleware + # teardown can be skipped on some client-abort paths. + end_turn(str(chat_id)) + # Release premium reservation if not finalized - if _premium_request_id and _premium_reserved > 0 and user_id: + if _premium_request_id and _premium_reserved_micros > 0 and user_id: try: from app.services.token_quota_service import TokenQuotaService @@ -2662,9 +3715,9 @@ async def stream_new_chat( await TokenQuotaService.premium_release( db_session=quota_session, user_id=UUID(user_id), - reserved_tokens=_premium_reserved, + reserved_micros=_premium_reserved_micros, ) - _premium_reserved = 0 + _premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to release premium quota for user %s", user_id @@ -2688,6 +3741,81 @@ async def stream_new_chat( with contextlib.suppress(Exception): await session.close() + # Server-side assistant-message + token_usage finalization. + # Runs after the main session has been closed (uses its own + # shielded session) so we don't fight the same DB connection. + # Idempotent against the legacy frontend appendMessage: + # * the assistant row was already INSERTed by + # ``persist_assistant_shell`` above, so this just UPDATEs + # it with the rich ContentPart[] from the builder. + # * token_usage uses INSERT ... ON CONFLICT DO NOTHING + # against migration 142's partial unique index, so a + # racing append_message recovery branch can never + # double-write. + # ``mark_interrupted`` closes any open text/reasoning blocks + # and flips running tool-calls (no result) to state=aborted + # so the persisted JSONB reflects a coherent end-state even + # on client disconnect. + # Never raises (best-effort, logs only). 
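# ---------------------------------------------------------------------------
# Aside: the idempotency tool referenced above, in minimal form; a sketch,
# with ``TokenUsage`` standing in for the real mapped class. PostgreSQL's
# ON CONFLICT DO NOTHING against the partial unique index means a racing
# legacy writer and this finalizer cannot double-insert.
from sqlalchemy.dialects.postgresql import insert

async def record_usage_sketch(session, message_id: int, usage: dict) -> None:
    stmt = (
        insert(TokenUsage)  # hypothetical mapped class
        .values(message_id=message_id, **usage)
        .on_conflict_do_nothing()  # the second writer becomes a no-op
    )
    await session.execute(stmt)
# ---------------------------------------------------------------------------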
+ if ( + stream_result + and stream_result.turn_id + and stream_result.assistant_message_id + ): + from app.tasks.chat.persistence import finalize_assistant_turn + + builder_stats: dict[str, int] | None = None + if stream_result.content_builder is not None: + stream_result.content_builder.mark_interrupted() + # Snapshot stats BEFORE deepcopy in ``snapshot()`` so + # the perf log records the actual finalised payload + # (post-mark_interrupted), not the live-mutating + # builder state. + builder_stats = stream_result.content_builder.stats() + content_payload = stream_result.content_builder.snapshot() + else: + # Defensive fallback — we always set the builder + # alongside ``assistant_message_id`` above, so this + # branch only fires if a future refactor ever + # decouples them. Persist whatever accumulated + # text we captured so the row at least renders. + content_payload = [ + { + "type": "text", + "text": stream_result.accumulated_text or "", + } + ] + + if builder_stats is not None: + _perf_log.info( + "[stream_new_chat] finalize_payload chat_id=%s " + "message_id=%s parts=%d bytes=%d text=%d " + "reasoning=%d tool_calls=%d " + "tool_calls_completed=%d tool_calls_aborted=%d " + "thinking_step_parts=%d step_separators=%d", + chat_id, + stream_result.assistant_message_id, + builder_stats["parts"], + builder_stats["bytes"], + builder_stats["text"], + builder_stats["reasoning"], + builder_stats["tool_calls"], + builder_stats["tool_calls_completed"], + builder_stats["tool_calls_aborted"], + builder_stats["thinking_step_parts"], + builder_stats["step_separators"], + ) + + await finalize_assistant_turn( + message_id=stream_result.assistant_message_id, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + turn_id=stream_result.turn_id, + content=content_payload, + accumulator=accumulator, + ) + # Persist any sandbox-produced files to local storage so they # remain downloadable after the Daytona sandbox auto-deletes. if stream_result and stream_result.sandbox_files: @@ -2707,11 +3835,11 @@ async def stream_new_chat( # Skip on ``BusyError`` (caller never acquired the lock). if not _busy_error_raised: with contextlib.suppress(Exception): - if _release_busy_lock(str(chat_id)): - _perf_log.info( - "[stream_new_chat] released stale busy lock (chat_id=%s)", - chat_id, - ) + end_turn(str(chat_id)) + _perf_log.info( + "[stream_new_chat] end_turn cleanup (chat_id=%s)", + chat_id, + ) # Break circular refs held by the agent graph, tools, and LLM # wrappers so the GC can reclaim them in a single pass. @@ -2765,83 +3893,218 @@ async def stream_resume_chat( # Skip the finally release on ``BusyError`` (caller never acquired the lock). 
_busy_error_raised = False + _emit_stream_error = partial( + _emit_stream_terminal_error, + streaming_service=streaming_service, + flow="resume", + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + ) + session = async_session_maker() try: if user_id: await set_ai_responding(session, chat_id, UUID(user_id)) agent_config: AgentConfig | None = None - _t0 = time.perf_counter() - if llm_config_id >= 0: - agent_config = await load_agent_config( - session=session, - config_id=llm_config_id, - search_space_id=search_space_id, + requested_llm_config_id = llm_config_id + + async def _load_llm_bundle( + config_id: int, + ) -> tuple[Any, AgentConfig | None, str | None]: + if config_id >= 0: + loaded_agent_config = await load_agent_config( + session=session, + config_id=config_id, + search_space_id=search_space_id, + ) + if not loaded_agent_config: + return ( + None, + None, + f"Failed to load NewLLMConfig with id {config_id}", + ) + return ( + create_chat_litellm_from_agent_config(loaded_agent_config), + loaded_agent_config, + None, + ) + + loaded_llm_config = load_global_llm_config_by_id(config_id) + if not loaded_llm_config: + return None, None, f"Failed to load LLM config with id {config_id}" + return ( + create_chat_litellm_from_config(loaded_llm_config), + AgentConfig.from_yaml_config(loaded_llm_config), + None, ) - if not agent_config: - yield streaming_service.format_error( - f"Failed to load NewLLMConfig with id {llm_config_id}" + + _t0 = time.perf_counter() + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=llm_config_id, ) - yield streaming_service.format_done() - return - llm = create_chat_litellm_from_agent_config(agent_config) - else: - llm_config = load_global_llm_config_by_id(llm_config_id) - if not llm_config: - yield streaming_service.format_error( - f"Failed to load LLM config with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - llm = create_chat_litellm_from_config(llm_config) - agent_config = AgentConfig.from_yaml_config(llm_config) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return _perf_log.info( "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 ) - # Premium quota reservation (same logic as stream_new_chat) - _resume_premium_reserved = 0 + # Premium credit reservation (same logic as stream_new_chat). 
+        _resume_premium_reserved_micros = 0
         _resume_premium_request_id: str | None = None
         _resume_needs_premium = (
-            agent_config is not None
-            and user_id
-            and (agent_config.is_premium or agent_config.is_auto_mode)
+            agent_config is not None and user_id and agent_config.is_premium
         )
         if _resume_needs_premium:
             import uuid as _uuid

-            from app.config import config as _app_config
-            from app.services.token_quota_service import TokenQuotaService
+            from app.services.token_quota_service import (
+                TokenQuotaService,
+                estimate_call_reserve_micros,
+            )

             _resume_premium_request_id = _uuid.uuid4().hex[:16]
-            reserve_amount = min(
-                agent_config.quota_reserve_tokens
-                or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
-                _app_config.QUOTA_MAX_RESERVE_PER_CALL,
+            _resume_litellm_params = agent_config.litellm_params or {}
+            _resume_base_model = (
+                _resume_litellm_params.get("base_model")
+                or agent_config.model_name
+                or ""
+            )
+            reserve_amount_micros = estimate_call_reserve_micros(
+                base_model=_resume_base_model,
+                quota_reserve_tokens=agent_config.quota_reserve_tokens,
             )
             async with shielded_async_session() as quota_session:
                 quota_result = await TokenQuotaService.premium_reserve(
                     db_session=quota_session,
                     user_id=UUID(user_id),
                     request_id=_resume_premium_request_id,
-                    reserve_tokens=reserve_amount,
+                    reserve_micros=reserve_amount_micros,
                 )
-            _resume_premium_reserved = reserve_amount
+            _resume_premium_reserved_micros = reserve_amount_micros

             if not quota_result.allowed:
-                if agent_config.is_premium:
-                    yield streaming_service.format_error(
-                        "Premium token quota exceeded. Please purchase more tokens to continue using premium models."
+                if requested_llm_config_id == 0:
+                    try:
+                        llm_config_id = (
+                            await resolve_or_get_pinned_llm_config_id(
+                                session,
+                                thread_id=chat_id,
+                                search_space_id=search_space_id,
+                                user_id=user_id,
+                                selected_llm_config_id=0,
+                                force_repin_free=True,
+                            )
+                        ).resolved_llm_config_id
+                    except ValueError as pin_error:
+                        yield _emit_stream_error(
+                            message=str(pin_error),
+                            error_kind="server_error",
+                            error_code="SERVER_ERROR",
+                        )
+                        yield streaming_service.format_done()
+                        return
+
+                    llm, agent_config, llm_load_error = await _load_llm_bundle(
+                        llm_config_id
+                    )
+                    if llm_load_error:
+                        yield _emit_stream_error(
+                            message=llm_load_error,
+                            error_kind="server_error",
+                            error_code="SERVER_ERROR",
+                        )
+                        yield streaming_service.format_done()
+                        return
+                    _resume_premium_request_id = None
+                    _resume_premium_reserved_micros = 0
+                    _log_chat_stream_error(
+                        flow="resume",
+                        error_kind="premium_quota_exhausted",
+                        error_code="PREMIUM_QUOTA_EXHAUSTED",
+                        severity="info",
+                        is_expected=True,
+                        request_id=request_id,
+                        thread_id=chat_id,
+                        search_space_id=search_space_id,
+                        user_id=user_id,
+                        message=(
+                            "Premium quota exhausted on pinned model; auto-fallback switched to a free model"
+                        ),
+                        extra={
+                            "fallback_config_id": llm_config_id,
+                            "auto_fallback": True,
+                        },
+                    )
+                else:
+                    yield _emit_stream_error(
+                        message=(
+                            "Buy more credits to continue with this model, or switch to a free model"
+                        ),
+                        error_kind="premium_quota_exhausted",
+                        error_code="PREMIUM_QUOTA_EXHAUSTED",
+                        severity="info",
+                        is_expected=True,
+                        extra={
+                            "resolved_config_id": llm_config_id,
+                            "auto_fallback": False,
+                        },
                     )
                     yield streaming_service.format_done()
                     return
-                _resume_premium_request_id = None
-                _resume_premium_reserved = 0

         if not llm:
-            yield streaming_service.format_error("Failed to create LLM instance")
+            yield _emit_stream_error(
+                message="Failed to create LLM instance",
+                error_kind="server_error",
+                error_code="SERVER_ERROR",
+ ) yield streaming_service.format_done() return + # Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``: + # one cheap probe before the agent is rebuilt so a 429'd pin gets + # repinned without burning planner/classifier/title calls first. + # See ``stream_new_chat`` for the full rationale on the speculative + # parallel build pattern below. + preflight_needed = ( + requested_llm_config_id == 0 + and llm_config_id < 0 + and not is_recently_healthy(llm_config_id) + ) + preflight_task: asyncio.Task[None] | None = None + _t_preflight = 0.0 + if preflight_needed: + _t_preflight = time.perf_counter() + preflight_task = asyncio.create_task( + _preflight_llm(llm), + name=f"auto_pin_preflight_resume:{llm_config_id}", + ) + _t0 = time.perf_counter() connector_service = ConnectorService(session, search_space_id=search_space_id) @@ -2866,8 +4129,13 @@ async def stream_resume_chat( from app.config import config as _app_config _t0 = time.perf_counter() - if _app_config.MULTI_AGENT_CHAT_ENABLED: - agent = await create_registry_deep_agent( + agent_factory = ( + create_registry_deep_agent + if _app_config.MULTI_AGENT_CHAT_ENABLED + else create_surfsense_deep_agent + ) + agent_build_task = asyncio.create_task( + agent_factory( llm=llm, search_space_id=search_space_id, db_session=session, @@ -2880,22 +4148,99 @@ async def stream_resume_chat( thread_visibility=visibility, filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, - ) - else: - agent = await create_surfsense_deep_agent( - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - disabled_tools=disabled_tools, - ) + ), + name="agent_build:stream_resume", + ) + + agent: Any = None + if preflight_task is not None: + try: + await preflight_task + mark_healthy(llm_config_id) + _perf_log.info( + "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)", + llm_config_id, + time.perf_counter() - _t_preflight, + ) + except Exception as preflight_exc: + # Same session-safety rationale as ``stream_new_chat``. 
+ await _settle_speculative_agent_build(agent_build_task) + if not _is_provider_rate_limited(preflight_exc): + raise + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, reason="preflight_rate_limited" + ) + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) + if llm_load_error or not llm: + yield _emit_stream_error( + message=llm_load_error or "Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + mark_healthy(llm_config_id) + _log_chat_stream_error( + flow="resume", + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model failed preflight; switched to another " + "eligible model and continuing." + ), + extra={ + "auto_runtime_recover": True, + "preflight": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + agent = await agent_factory( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + disabled_tools=disabled_tools, + ) + + if agent is None: + agent = await agent_build_task _perf_log.info( "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 ) @@ -2937,34 +4282,174 @@ async def stream_resume_chat( "turn-info", {"chat_turn_id": stream_result.turn_id}, ) + yield streaming_service.format_data("turn-status", {"status": "busy"}) + + # Pre-write a fresh assistant row for this resume turn. The + # original (interrupted) ``stream_new_chat`` invocation already + # persisted its own assistant row anchored to a different + # ``turn_id``; resume allocates a new ``turn_id`` (above) so we + # need a separate row keyed on the same ``(thread_id, turn_id, + # ASSISTANT)`` invariant. Idempotent against migration 141's + # partial unique index — recovers existing id on retry. + from app.tasks.chat.content_builder import AssistantContentBuilder + from app.tasks.chat.persistence import persist_assistant_shell + + assistant_message_id = await persist_assistant_shell( + chat_id=chat_id, + user_id=user_id, + turn_id=stream_result.turn_id, + ) + if assistant_message_id is None: + yield _emit_stream_error( + message=( + "We couldn't initialize the assistant message. Please try again." 
+ ), + error_kind="server_error", + error_code="MESSAGE_PERSIST_FAILED", + ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) + yield streaming_service.format_finish_step() + yield streaming_service.format_finish() + yield streaming_service.format_done() + return + + # Emit canonical assistant message id BEFORE any LLM streaming + # so the FE can rename ``pendingInterrupt.assistantMsgId`` to + # ``msg-{assistant_message_id}`` immediately. Resume does NOT + # emit ``data-user-message-id`` because the user row is from + # the original interrupted turn (different ``turn_id``) and is + # never re-persisted here. See B5 in the + # ``sse-based_message_id_handshake`` plan. + yield streaming_service.format_data( + "assistant-message-id", + {"message_id": assistant_message_id, "turn_id": stream_result.turn_id}, + ) + + stream_result.assistant_message_id = assistant_message_id + stream_result.content_builder = AssistantContentBuilder() + + # Resume path doesn't carry new ``mentioned_document_ids`` — + # those are seeded in the original turn. We still pass a + # context so future middleware extensions (Phase 2) can rely on + # ``runtime.context`` always being populated. + runtime_context = SurfSenseContextSchema( + search_space_id=search_space_id, + request_id=request_id, + turn_id=stream_result.turn_id, + ) _t_stream_start = time.perf_counter() _first_event_logged = False - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=Command(resume={"decisions": decisions}), - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking-resume", - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, + runtime_rate_limit_recovered = False + while True: + try: + async for sse in _stream_agent_events( + agent=agent, + config=config, + input_data=Command(resume={"decisions": decisions}), + streaming_service=streaming_service, + result=stream_result, + step_prefix="thinking-resume", + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + runtime_context=runtime_context, + content_builder=stream_result.content_builder, + ): + if not _first_event_logged: + _perf_log.info( + "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, + ) + _first_event_logged = True + yield sse + break + except Exception as stream_exc: + can_runtime_recover = ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + and llm_config_id < 0 + and not _first_event_logged + and _is_provider_rate_limited(stream_exc) ) - _first_event_logged = True - yield sse + if not can_runtime_recover: + raise + + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + # Ensure the same-request recovery retry does not trip the + # BusyMutex lock retained by the failed attempt. 
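+                # (``end_turn`` performs the same lock release the
+                # middleware teardown would normally do at stream end.)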
+                end_turn(str(chat_id))
+                mark_runtime_cooldown(
+                    previous_config_id,
+                    reason="provider_rate_limited",
+                )
+                llm_config_id = (
+                    await resolve_or_get_pinned_llm_config_id(
+                        session,
+                        thread_id=chat_id,
+                        search_space_id=search_space_id,
+                        user_id=user_id,
+                        selected_llm_config_id=0,
+                        exclude_config_ids={previous_config_id},
+                    )
+                ).resolved_llm_config_id
+
+                llm, agent_config, llm_load_error = await _load_llm_bundle(
+                    llm_config_id
+                )
+                if llm_load_error or not llm:
+                    raise stream_exc
+
+                _t0 = time.perf_counter()
+                # Rebuild with the same ``agent_factory`` as the primary
+                # path so the multi-agent registry choice and the caller's
+                # ``disabled_tools`` survive the recovery retry.
+                agent = await agent_factory(
+                    llm=llm,
+                    search_space_id=search_space_id,
+                    db_session=session,
+                    connector_service=connector_service,
+                    checkpointer=checkpointer,
+                    user_id=user_id,
+                    thread_id=chat_id,
+                    agent_config=agent_config,
+                    firecrawl_api_key=firecrawl_api_key,
+                    thread_visibility=visibility,
+                    filesystem_selection=filesystem_selection,
+                    disabled_tools=disabled_tools,
+                )
+                _perf_log.info(
+                    "[stream_resume] Runtime rate-limit recovery repinned "
+                    "config_id=%s -> %s and rebuilt agent in %.3fs",
+                    previous_config_id,
+                    llm_config_id,
+                    time.perf_counter() - _t0,
+                )
+                _log_chat_stream_error(
+                    flow="resume",
+                    error_kind="rate_limited",
+                    error_code="RATE_LIMITED",
+                    severity="info",
+                    is_expected=True,
+                    request_id=request_id,
+                    thread_id=chat_id,
+                    search_space_id=search_space_id,
+                    user_id=user_id,
+                    message=(
+                        "Auto-pinned model hit runtime rate limit; switched to "
+                        "another eligible model and retried."
+                    ),
+                    extra={
+                        "auto_runtime_recover": True,
+                        "previous_config_id": previous_config_id,
+                        "fallback_config_id": llm_config_id,
+                    },
+                )
+                continue
         _perf_log.info(
             "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
             time.perf_counter() - _t_stream_start,
@@ -2973,9 +4458,10 @@
         if stream_result.is_interrupted:
             usage_summary = accumulator.per_message_summary()
             _perf_log.info(
-                "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
+                "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
                 len(accumulator.calls),
                 accumulator.grand_total,
+                accumulator.total_cost_micros,
                 usage_summary,
             )
             if usage_summary:
                 yield streaming_service.format_data(
                     "token-usage",
                     {
                         "prompt_tokens": accumulator.total_prompt_tokens,
                         "completion_tokens": accumulator.total_completion_tokens,
                         "total_tokens": accumulator.grand_total,
+                        "cost_micros": accumulator.total_cost_micros,
                         "call_details": accumulator.serialized_calls(),
                     },
                 )
@@ -2995,28 +4482,23 @@
             yield streaming_service.format_done()
             return
 
-    # Finalize premium quota for resume path
+    # Finalize premium credit debit for resume path with the actual
+    # provider cost reported by LiteLLM (sum of cost across all
+    # calls in the turn).
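+    # Worked example: a turn whose calls cost $0.0123 in total arrives as
+    # cost_micros=12_300 (micro-USD), which ``premium_finalize`` settles
+    # against the amount reserved at turn start.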
if _resume_premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService - if agent_config and agent_config.is_auto_mode: - from app.services.llm_router_service import LLMRouterService - - actual_premium_tokens = LLMRouterService.compute_premium_tokens( - accumulator.calls - ) - else: - actual_premium_tokens = accumulator.grand_total - async with shielded_async_session() as quota_session: await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=UUID(user_id), request_id=_resume_premium_request_id, - actual_tokens=actual_premium_tokens, - reserved_tokens=_resume_premium_reserved, + actual_micros=accumulator.total_cost_micros, + reserved_micros=_resume_premium_reserved_micros, ) + _resume_premium_request_id = None + _resume_premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s (resume)", @@ -3026,9 +4508,10 @@ async def stream_resume_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] normal resume_chat: calls=%d total=%d summary=%s", + "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3039,17 +4522,12 @@ async def stream_resume_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - - except BusyError as e: - _busy_error_raised = True - yield streaming_service.format_error(str(e)) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -3057,18 +4535,55 @@ async def stream_resume_chat( except Exception as e: import traceback + # ``BusyError`` fires before the agent acquires the lock; the + # cleanup path must skip lock release to avoid freeing the + # in-flight caller's lock. Classification is handled below. + if isinstance(e, BusyError): + _busy_error_raised = True + + ( + error_kind, + error_code, + severity, + is_expected, + user_message, + error_extra, + ) = _classify_stream_exception(e, flow_label="resume") error_message = f"Error during resume: {e!s}" print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") - yield streaming_service.format_error(error_message) + if error_code == "TURN_CANCELLING": + status_payload: dict[str, Any] = {"status": "cancelling"} + if error_extra: + status_payload.update(error_extra) + yield streaming_service.format_data("turn-status", status_payload) + else: + yield streaming_service.format_data("turn-status", {"status": "busy"}) + yield _emit_stream_error( + message=user_message, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + extra=error_extra, + ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() finally: with anyio.CancelScope(shield=True): + # Authoritative fallback cleanup for lock/cancel state. 
Middleware + # teardown can be skipped on some client-abort paths. + end_turn(str(chat_id)) + # Release premium reservation if not finalized - if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id: + if ( + _resume_premium_request_id + and _resume_premium_reserved_micros > 0 + and user_id + ): try: from app.services.token_quota_service import TokenQuotaService @@ -3076,9 +4591,9 @@ async def stream_resume_chat( await TokenQuotaService.premium_release( db_session=quota_session, user_id=UUID(user_id), - reserved_tokens=_resume_premium_reserved, + reserved_micros=_resume_premium_reserved_micros, ) - _resume_premium_reserved = 0 + _resume_premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to release premium quota for user %s (resume)", user_id @@ -3102,15 +4617,73 @@ async def stream_resume_chat( with contextlib.suppress(Exception): await session.close() + # Server-side assistant-message + token_usage finalization for + # the resume flow. The original user message was persisted by + # the original (interrupted) ``stream_new_chat`` invocation; + # the resume's own ``persist_assistant_shell`` write lives at + # the new ``turn_id`` above. This finalize updates that row + # with the rich ContentPart[] from the builder and writes + # token_usage idempotently via migration 142's partial + # unique index. Best-effort, never raises. + if ( + stream_result + and stream_result.turn_id + and stream_result.assistant_message_id + ): + from app.tasks.chat.persistence import finalize_assistant_turn + + builder_stats: dict[str, int] | None = None + if stream_result.content_builder is not None: + stream_result.content_builder.mark_interrupted() + builder_stats = stream_result.content_builder.stats() + content_payload = stream_result.content_builder.snapshot() + else: + content_payload = [ + { + "type": "text", + "text": stream_result.accumulated_text or "", + } + ] + + if builder_stats is not None: + _perf_log.info( + "[stream_resume] finalize_payload chat_id=%s " + "message_id=%s parts=%d bytes=%d text=%d " + "reasoning=%d tool_calls=%d " + "tool_calls_completed=%d tool_calls_aborted=%d " + "thinking_step_parts=%d step_separators=%d", + chat_id, + stream_result.assistant_message_id, + builder_stats["parts"], + builder_stats["bytes"], + builder_stats["text"], + builder_stats["reasoning"], + builder_stats["tool_calls"], + builder_stats["tool_calls_completed"], + builder_stats["tool_calls_aborted"], + builder_stats["thinking_step_parts"], + builder_stats["step_separators"], + ) + + await finalize_assistant_turn( + message_id=stream_result.assistant_message_id, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + turn_id=stream_result.turn_id, + content=content_payload, + accumulator=accumulator, + ) + # Release the lock from the original interrupted turn or any # re-interrupt/bailout. Skip on ``BusyError`` (lock not held here). 
if not _busy_error_raised: with contextlib.suppress(Exception): - if _release_busy_lock(str(chat_id)): - _perf_log.info( - "[stream_resume] released stale busy lock (chat_id=%s)", - chat_id, - ) + end_turn(str(chat_id)) + _perf_log.info( + "[stream_resume] end_turn cleanup (chat_id=%s)", + chat_id, + ) agent = llm = connector_service = None stream_result = None diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py index 6912ffe5a..3c9f27303 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py @@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import ( IndexingPipelineService, PlaceholderInfo, ) +from app.services.composio_service import ComposioService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.google_credentials import ( - COMPOSIO_GOOGLE_CONNECTOR_TYPES, - build_composio_credentials, -) +from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES from .base import ( check_duplicate_document_by_hash, @@ -44,6 +42,10 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]] HEARTBEAT_INTERVAL_SECONDS = 30 +def _format_calendar_event_to_markdown(event: dict) -> str: + return GoogleCalendarConnector.format_event_to_markdown(None, event) + + def _build_connector_doc( event: dict, event_markdown: str, @@ -150,7 +152,14 @@ async def index_google_calendar_events( ) return 0, 0, f"Connector with ID {connector_id} not found" - # ── Credential building ─────────────────────────────────────── + is_composio_connector = ( + connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES + ) + calendar_client = None + composio_service = None + connected_account_id = None + + # ── Credential/client building ──────────────────────────────── if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: connected_account_id = connector.config.get("composio_connected_account_id") if not connected_account_id: @@ -161,7 +170,7 @@ async def index_google_calendar_events( {"error_type": "MissingComposioAccount"}, ) return 0, 0, "Composio connected_account_id not found" - credentials = build_composio_credentials(connected_account_id) + composio_service = ComposioService() else: config_data = connector.config @@ -229,12 +238,13 @@ async def index_google_calendar_events( {"stage": "client_initialization"}, ) - calendar_client = GoogleCalendarConnector( - credentials=credentials, - session=session, - user_id=user_id, - connector_id=connector_id, - ) + if not is_composio_connector: + calendar_client = GoogleCalendarConnector( + credentials=credentials, + session=session, + user_id=user_id, + connector_id=connector_id, + ) # Handle 'undefined' string from frontend (treat as None) if start_date == "undefined" or start_date == "": @@ -300,9 +310,26 @@ async def index_google_calendar_events( ) try: - events, error = await calendar_client.get_all_primary_calendar_events( - start_date=start_date_str, end_date=end_date_str - ) + if is_composio_connector: + start_dt = parse_date_flexible(start_date_str).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + end_dt = parse_date_flexible(end_date_str).replace( + hour=23, minute=59, second=59, microsecond=0 + ) + events, error = await composio_service.get_calendar_events( + connected_account_id=connected_account_id, + 
entity_id=f"surfsense_{user_id}", + time_min=start_dt.isoformat(), + time_max=end_dt.isoformat(), + max_results=250, + ) + if not events and not error: + error = "No events found in the specified date range." + else: + events, error = await calendar_client.get_all_primary_calendar_events( + start_date=start_date_str, end_date=end_date_str + ) if error: if "No events found" in error: @@ -381,7 +408,7 @@ async def index_google_calendar_events( documents_skipped += 1 continue - event_markdown = calendar_client.format_event_to_markdown(event) + event_markdown = _format_calendar_event_to_markdown(event) if not event_markdown.strip(): logger.warning(f"Skipping event with no content: {event_summary}") documents_skipped += 1 diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index 21cdbd29f..686f13d9e 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -9,6 +9,8 @@ import asyncio import logging import time from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import Any from sqlalchemy import String, cast, select from sqlalchemy.exc import SQLAlchemyError @@ -37,6 +39,7 @@ from app.indexing_pipeline.indexing_pipeline_service import ( IndexingPipelineService, PlaceholderInfo, ) +from app.services.composio_service import ComposioService from app.services.llm_service import get_user_long_context_llm from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService @@ -45,10 +48,7 @@ from app.tasks.connector_indexers.base import ( get_connector_by_id, update_connector_last_indexed, ) -from app.utils.google_credentials import ( - COMPOSIO_GOOGLE_CONNECTOR_TYPES, - build_composio_credentials, -) +from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES ACCEPTED_DRIVE_CONNECTOR_TYPES = { SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, @@ -61,6 +61,209 @@ HEARTBEAT_INTERVAL_SECONDS = 30 logger = logging.getLogger(__name__) +class ComposioDriveClient: + """Google Drive client facade backed by Composio tool execution. + + Composio-managed OAuth connections can execute tools without exposing raw + OAuth tokens through connected account state. 
+ """ + + def __init__( + self, + session: AsyncSession, + connector_id: int, + connected_account_id: str, + entity_id: str, + ): + self.session = session + self.connector_id = connector_id + self.connected_account_id = connected_account_id + self.entity_id = entity_id + self.composio = ComposioService() + + async def list_files( + self, + query: str = "", + fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)", + page_size: int = 100, + page_token: str | None = None, + ) -> tuple[list[dict[str, Any]], str | None, str | None]: + params: dict[str, Any] = { + "page_size": min(page_size, 100), + "fields": fields, + } + if query: + params["q"] = query + if page_token: + params["page_token"] = page_token + + result = await self.composio.execute_tool( + connected_account_id=self.connected_account_id, + tool_name="GOOGLEDRIVE_LIST_FILES", + params=params, + entity_id=self.entity_id, + ) + if not result.get("success"): + return [], None, result.get("error", "Unknown error") + + data = result.get("data", {}) + files = [] + next_token = None + if isinstance(data, dict): + inner_data = data.get("data", data) + if isinstance(inner_data, dict): + files = inner_data.get("files", []) + next_token = inner_data.get("nextPageToken") or inner_data.get( + "next_page_token" + ) + elif isinstance(data, list): + files = data + + return files, next_token, None + + async def get_file_metadata( + self, file_id: str, fields: str = "*" + ) -> tuple[dict[str, Any] | None, str | None]: + result = await self.composio.execute_tool( + connected_account_id=self.connected_account_id, + tool_name="GOOGLEDRIVE_GET_FILE_METADATA", + params={"file_id": file_id, "fields": fields}, + entity_id=self.entity_id, + ) + if not result.get("success"): + return None, result.get("error", "Unknown error") + + data = result.get("data", {}) + if isinstance(data, dict): + inner_data = data.get("data", data) + if isinstance(inner_data, dict): + return inner_data, None + + return None, "Could not extract metadata from Composio response" + + async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]: + return await self._download_file_content(file_id) + + async def download_file_to_disk( + self, + file_id: str, + dest_path: str, + chunksize: int = 5 * 1024 * 1024, + ) -> str | None: + del chunksize + content, error = await self.download_file(file_id) + if error: + return error + if content is None: + return "No content returned from Composio" + Path(dest_path).write_bytes(content) + return None + + async def export_google_file( + self, file_id: str, mime_type: str + ) -> tuple[bytes | None, str | None]: + return await self._download_file_content(file_id, mime_type=mime_type) + + async def _download_file_content( + self, file_id: str, mime_type: str | None = None + ) -> tuple[bytes | None, str | None]: + params: dict[str, Any] = {"file_id": file_id} + if mime_type: + params["mime_type"] = mime_type + + result = await self.composio.execute_tool( + connected_account_id=self.connected_account_id, + tool_name="GOOGLEDRIVE_DOWNLOAD_FILE", + params=params, + entity_id=self.entity_id, + ) + if not result.get("success"): + return None, result.get("error", "Unknown error") + + return self._read_download_result(result.get("data")) + + def _read_download_result(self, data: Any) -> tuple[bytes | None, str | None]: + if isinstance(data, bytes): + return data, None + + file_path: str | None = None + if isinstance(data, str): + file_path = data + elif 
isinstance(data, dict): + inner_data = data.get("data", data) + if isinstance(inner_data, dict): + for key in ("file_path", "downloaded_file_content", "path", "uri"): + value = inner_data.get(key) + if isinstance(value, str): + file_path = value + break + if isinstance(value, dict): + nested = ( + value.get("file_path") + or value.get("downloaded_file_content") + or value.get("path") + or value.get("uri") + or value.get("s3url") + ) + if isinstance(nested, str): + file_path = nested + break + + if not file_path: + return None, "No file path/content returned from Composio" + + if file_path.startswith(("http://", "https://")): + try: + import urllib.request + + with urllib.request.urlopen(file_path, timeout=60) as response: + return response.read(), None + except Exception as e: + return None, f"Failed to download Composio file URL: {e!s}" + + path_obj = Path(file_path) + if path_obj.is_absolute() or ".composio" in str(path_obj): + if not path_obj.exists(): + return None, f"File not found at path: {file_path}" + return path_obj.read_bytes(), None + + try: + import base64 + + return base64.b64decode(file_path), None + except Exception: + return file_path.encode("utf-8"), None + + +def _build_drive_client_for_connector( + session: AsyncSession, + connector_id: int, + connector: object, + user_id: str, +) -> tuple[GoogleDriveClient | ComposioDriveClient | None, str | None]: + if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: + connected_account_id = connector.config.get("composio_connected_account_id") + if not connected_account_id: + return None, ( + f"Composio connected_account_id not found for connector {connector_id}" + ) + return ( + ComposioDriveClient( + session, + connector_id, + connected_account_id, + entity_id=f"surfsense_{user_id}", + ), + None, + ) + + token_encrypted = connector.config.get("_token_encrypted", False) + if token_encrypted and not config.SECRET_KEY: + return None, "SECRET_KEY not configured but credentials are marked as encrypted" + + return GoogleDriveClient(session, connector_id), None + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -927,34 +1130,17 @@ async def index_google_drive_files( {"stage": "client_initialization"}, ) - pre_built_credentials = None - if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - connected_account_id = connector.config.get("composio_connected_account_id") - if not connected_account_id: - error_msg = f"Composio connected_account_id not found for connector {connector_id}" - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing Composio account", - {"error_type": "MissingComposioAccount"}, - ) - return 0, 0, error_msg, 0 - pre_built_credentials = build_composio_credentials(connected_account_id) - else: - token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted and not config.SECRET_KEY: - await task_logger.log_task_failure( - log_entry, - "SECRET_KEY not configured but credentials are encrypted", - "Missing SECRET_KEY", - {"error_type": "MissingSecretKey"}, - ) - return ( - 0, - 0, - "SECRET_KEY not configured but credentials are marked as encrypted", - 0, - ) + drive_client, client_error = _build_drive_client_for_connector( + session, connector_id, connector, user_id + ) + if client_error or not drive_client: + await task_logger.log_task_failure( + log_entry, + client_error or "Failed to initialize Google Drive client", + "Missing connector 
credentials", + {"error_type": "ClientInitializationError"}, + ) + return 0, 0, client_error, 0 connector_enable_summary = getattr(connector, "enable_summary", True) connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False) @@ -963,10 +1149,6 @@ async def index_google_drive_files( from app.services.llm_service import get_vision_llm vision_llm = await get_vision_llm(session, search_space_id) - drive_client = GoogleDriveClient( - session, connector_id, credentials=pre_built_credentials - ) - if not folder_id: error_msg = "folder_id is required for Google Drive indexing" await task_logger.log_task_failure( @@ -979,8 +1161,14 @@ async def index_google_drive_files( folder_tokens = connector.config.get("folder_tokens", {}) start_page_token = folder_tokens.get(target_folder_id) + is_composio_connector = ( + connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES + ) can_use_delta = ( - use_delta_sync and start_page_token and connector.last_indexed_at + not is_composio_connector + and use_delta_sync + and start_page_token + and connector.last_indexed_at ) documents_unsupported = 0 @@ -1051,7 +1239,16 @@ async def index_google_drive_files( ) if documents_indexed > 0 or can_use_delta: - new_token, token_error = await get_start_page_token(drive_client) + if isinstance(drive_client, ComposioDriveClient): + ( + new_token, + token_error, + ) = await drive_client.composio.get_drive_start_page_token( + drive_client.connected_account_id, + drive_client.entity_id, + ) + else: + new_token, token_error = await get_start_page_token(drive_client) if new_token and not token_error: await session.refresh(connector) if "folder_tokens" not in connector.config: @@ -1137,32 +1334,17 @@ async def index_google_drive_single_file( ) return 0, error_msg - pre_built_credentials = None - if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - connected_account_id = connector.config.get("composio_connected_account_id") - if not connected_account_id: - error_msg = f"Composio connected_account_id not found for connector {connector_id}" - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing Composio account", - {"error_type": "MissingComposioAccount"}, - ) - return 0, error_msg - pre_built_credentials = build_composio_credentials(connected_account_id) - else: - token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted and not config.SECRET_KEY: - await task_logger.log_task_failure( - log_entry, - "SECRET_KEY not configured but credentials are encrypted", - "Missing SECRET_KEY", - {"error_type": "MissingSecretKey"}, - ) - return ( - 0, - "SECRET_KEY not configured but credentials are marked as encrypted", - ) + drive_client, client_error = _build_drive_client_for_connector( + session, connector_id, connector, user_id + ) + if client_error or not drive_client: + await task_logger.log_task_failure( + log_entry, + client_error or "Failed to initialize Google Drive client", + "Missing connector credentials", + {"error_type": "ClientInitializationError"}, + ) + return 0, client_error connector_enable_summary = getattr(connector, "enable_summary", True) connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False) @@ -1171,10 +1353,6 @@ async def index_google_drive_single_file( from app.services.llm_service import get_vision_llm vision_llm = await get_vision_llm(session, search_space_id) - drive_client = GoogleDriveClient( - session, connector_id, credentials=pre_built_credentials - ) - file, error = await get_file_by_id(drive_client, 
file_id) if error or not file: error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}" @@ -1276,32 +1454,18 @@ async def index_google_drive_selected_files( ) return 0, 0, [error_msg] - pre_built_credentials = None - if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - connected_account_id = connector.config.get("composio_connected_account_id") - if not connected_account_id: - error_msg = f"Composio connected_account_id not found for connector {connector_id}" - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing Composio account", - {"error_type": "MissingComposioAccount"}, - ) - return 0, 0, [error_msg] - pre_built_credentials = build_composio_credentials(connected_account_id) - else: - token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted and not config.SECRET_KEY: - error_msg = ( - "SECRET_KEY not configured but credentials are marked as encrypted" - ) - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing SECRET_KEY", - {"error_type": "MissingSecretKey"}, - ) - return 0, 0, [error_msg] + drive_client, client_error = _build_drive_client_for_connector( + session, connector_id, connector, user_id + ) + if client_error or not drive_client: + error_msg = client_error or "Failed to initialize Google Drive client" + await task_logger.log_task_failure( + log_entry, + error_msg, + "Missing connector credentials", + {"error_type": "ClientInitializationError"}, + ) + return 0, 0, [error_msg] connector_enable_summary = getattr(connector, "enable_summary", True) connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False) @@ -1310,10 +1474,6 @@ async def index_google_drive_selected_files( from app.services.llm_service import get_vision_llm vision_llm = await get_vision_llm(session, search_space_id) - drive_client = GoogleDriveClient( - session, connector_id, credentials=pre_built_credentials - ) - indexed, skipped, unsupported, errors = await _index_selected_files( drive_client, session, diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py index ef226087b..6697c0eb1 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py @@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import ( IndexingPipelineService, PlaceholderInfo, ) +from app.services.composio_service import ComposioService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.google_credentials import ( - COMPOSIO_GOOGLE_CONNECTOR_TYPES, - build_composio_credentials, -) +from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES from .base import ( calculate_date_range, @@ -44,6 +42,62 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]] HEARTBEAT_INTERVAL_SECONDS = 30 +def _normalize_composio_gmail_message(message: dict) -> dict: + if message.get("payload"): + return message + + headers = [] + header_values = { + "Subject": message.get("subject"), + "From": message.get("from") or message.get("sender"), + "To": message.get("to") or message.get("recipient"), + "Date": message.get("date"), + } + for name, value in header_values.items(): + if value: + headers.append({"name": name, "value": value}) + + return { + **message, + "id": message.get("id") + or message.get("message_id") + or 
message.get("messageId"), + "threadId": message.get("threadId") or message.get("thread_id"), + "payload": {"headers": headers}, + "snippet": message.get("snippet", ""), + "messageText": message.get("messageText") or message.get("body") or "", + } + + +def _format_gmail_message_to_markdown(message: dict) -> str: + headers = { + header.get("name", "").lower(): header.get("value", "") + for header in message.get("payload", {}).get("headers", []) + if isinstance(header, dict) + } + subject = headers.get("subject", "No Subject") + from_email = headers.get("from", "Unknown Sender") + to_email = headers.get("to", "Unknown Recipient") + date_str = headers.get("date", "Unknown Date") + message_text = ( + message.get("messageText") + or message.get("body") + or message.get("text") + or message.get("snippet", "") + ) + + return ( + f"# {subject}\n\n" + f"**From:** {from_email}\n" + f"**To:** {to_email}\n" + f"**Date:** {date_str}\n\n" + f"## Message Content\n\n{message_text}\n\n" + f"## Message Details\n\n" + f"- **Message ID:** {message.get('id', 'Unknown')}\n" + f"- **Thread ID:** {message.get('threadId', 'Unknown')}\n" + ) + + def _build_connector_doc( message: dict, markdown_content: str, @@ -162,7 +216,14 @@ async def index_google_gmail_messages( ) return 0, 0, error_msg - # ── Credential building ─────────────────────────────────────── + is_composio_connector = ( + connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES + ) + gmail_connector = None + composio_service = None + connected_account_id = None + + # ── Credential/client building ──────────────────────────────── if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: connected_account_id = connector.config.get("composio_connected_account_id") if not connected_account_id: @@ -173,7 +234,7 @@ async def index_google_gmail_messages( {"error_type": "MissingComposioAccount"}, ) return 0, 0, "Composio connected_account_id not found" - credentials = build_composio_credentials(connected_account_id) + composio_service = ComposioService() else: config_data = connector.config @@ -241,9 +302,10 @@ async def index_google_gmail_messages( {"stage": "client_initialization"}, ) - gmail_connector = GoogleGmailConnector( - credentials, session, user_id, connector_id - ) + if not is_composio_connector: + gmail_connector = GoogleGmailConnector( + credentials, session, user_id, connector_id + ) calculated_start_date, calculated_end_date = calculate_date_range( connector, start_date, end_date, default_days_back=365 @@ -254,11 +316,60 @@ async def index_google_gmail_messages( f"Fetching emails for connector {connector_id} " f"from {calculated_start_date} to {calculated_end_date}" ) - messages, error = await gmail_connector.get_recent_messages( - max_results=max_messages, - start_date=calculated_start_date, - end_date=calculated_end_date, - ) + if is_composio_connector: + query_parts = [] + if calculated_start_date: + query_parts.append(f"after:{calculated_start_date.replace('-', '/')}") + if calculated_end_date: + query_parts.append(f"before:{calculated_end_date.replace('-', '/')}") + query = " ".join(query_parts) + + messages = [] + page_token = None + error = None + while len(messages) < max_messages: + page_size = min(50, max_messages - len(messages)) + ( + page_messages, + page_token, + _estimate, + page_error, + ) = await composio_service.get_gmail_messages( + connected_account_id=connected_account_id, + entity_id=f"surfsense_{user_id}", + query=query, + max_results=page_size, + page_token=page_token, + ) + if page_error: + error = page_error 
+ break + for page_message in page_messages: + message_id = ( + page_message.get("id") + or page_message.get("message_id") + or page_message.get("messageId") + ) + if message_id: + ( + detail, + detail_error, + ) = await composio_service.get_gmail_message_detail( + connected_account_id=connected_account_id, + entity_id=f"surfsense_{user_id}", + message_id=message_id, + ) + if not detail_error and isinstance(detail, dict): + page_message = detail + messages.append(_normalize_composio_gmail_message(page_message)) + if not page_token: + break + else: + messages, error = await gmail_connector.get_recent_messages( + max_results=max_messages, + start_date=calculated_start_date, + end_date=calculated_end_date, + ) if error: error_message = error @@ -326,7 +437,12 @@ async def index_google_gmail_messages( documents_skipped += 1 continue - markdown_content = gmail_connector.format_message_to_markdown(message) + if is_composio_connector: + markdown_content = _format_gmail_message_to_markdown(message) + else: + markdown_content = gmail_connector.format_message_to_markdown( + message + ) if not markdown_content.strip(): logger.warning(f"Skipping message with no content: {message_id}") documents_skipped += 1 diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index 131627386..da8c4b7d1 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "surf-new-backend" -version = "0.0.19" +version = "0.0.21" description = "SurfSense Backend" requires-python = ">=3.12" dependencies = [ @@ -71,11 +71,11 @@ dependencies = [ "langchain>=1.2.13", "langgraph>=1.1.3", "langchain-community>=0.4.1", - "deepagents>=0.4.12", "stripe>=15.0.0", "azure-ai-documentintelligence>=1.0.2", - "litellm>=1.83.4", + "litellm>=1.83.7", "langchain-litellm>=0.6.4", + "deepagents>=0.4.12,<0.5", ] [dependency-groups] diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py new file mode 100644 index 000000000..a49d4eab2 --- /dev/null +++ b/surfsense_backend/scripts/verify_chat_image_capability.py @@ -0,0 +1,558 @@ +"""End-to-end smoke test for vision / image config wiring. + +Loads the live ``global_llm_config.yaml`` (no mocking, no fixtures) and +exercises every chat / vision / image-generation config + the OpenRouter +dynamic catalog. For each config the script: + +1. Reports the resolver classification (catalog-allow vs strict-block). +2. Optionally fires a tiny live API call against the provider: + - Chat configs: ``litellm.acompletion`` with a 1x1 PNG and the prompt + ``"reply with one word: ok"``. + - Vision configs: same, against the dedicated vision router pool. + - Image-gen configs: ``litellm.aimage_generation`` with a single tiny + prompt and ``n=1``. + - OpenRouter integration: samples one chat, one vision, one image-gen + model from the dynamically fetched catalog. + +Usage:: + + python -m scripts.verify_chat_image_capability # capability + connectivity + python -m scripts.verify_chat_image_capability --no-live # capability resolver only + +The script is meant to be runnable from the repository root or from +``surfsense_backend/`` and prints a short PASS/FAIL/SKIP summary at the +end so it's usable as a CI smoke check too. + +Live-mode caveat: each successful call costs a small amount of provider +credit (a few tokens or one tiny generated image per config). 
The +default size for image generation is ``1024x1024`` because Azure +GPT-image deployments reject smaller sizes; OpenRouter image-gen models +generally accept the same size. +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import sys +import time +from dataclasses import dataclass, field +from typing import Any + +# Bootstrap the surfsense_backend package on sys.path so the script runs +# from the repo root or from `surfsense_backend/` interchangeably. +_HERE = os.path.dirname(os.path.abspath(__file__)) +_BACKEND_ROOT = os.path.dirname(_HERE) +if _BACKEND_ROOT not in sys.path: + sys.path.insert(0, _BACKEND_ROOT) + +import litellm # noqa: E402 + +from app.config import config # noqa: E402 +from app.services.openrouter_integration_service import ( # noqa: E402 + _OPENROUTER_DYNAMIC_MARKER, + OpenRouterIntegrationService, +) +from app.services.provider_api_base import resolve_api_base # noqa: E402 +from app.services.provider_capabilities import ( # noqa: E402 + derive_supports_image_input, + is_known_text_only_chat_model, +) + +logging.basicConfig( + level=logging.WARNING, + format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", +) +# Quiet down LiteLLM's verbose router/cost logs so the script output is +# scannable. +logging.getLogger("LiteLLM").setLevel(logging.ERROR) +logging.getLogger("litellm").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) + +# 1x1 transparent PNG — used as the cheapest possible vision payload. +_TINY_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" +_TINY_PNG_DATA_URL = f"data:image/png;base64,{_TINY_PNG_B64}" + + +# --------------------------------------------------------------------------- +# Result accounting +# --------------------------------------------------------------------------- + + +@dataclass +class ProbeResult: + label: str + surface: str + config_id: int | str + capability_ok: bool | None = None + capability_note: str = "" + live_ok: bool | None = None + live_note: str = "" + duration_s: float = 0.0 + + +@dataclass +class Report: + results: list[ProbeResult] = field(default_factory=list) + + def add(self, r: ProbeResult) -> None: + self.results.append(r) + + def render(self) -> int: + passed = failed = skipped = 0 + print() + print("=" * 92) + print( + f"{'Surface':<14}{'ID':>8} {'Cap':>5} {'Live':>5} {'Time':>6} Label / notes" + ) + print("-" * 92) + for r in self.results: + + def _flag(value: bool | None) -> str: + if value is None: + return "skip" + return "ok" if value else "fail" + + cap = _flag(r.capability_ok) + live = _flag(r.live_ok) + if r.capability_ok is False or r.live_ok is False: + failed += 1 + elif r.capability_ok is None and r.live_ok is None: + skipped += 1 + else: + passed += 1 + print( + f"{r.surface:<14}{r.config_id!s:>8} {cap:>5} {live:>5} " + f"{r.duration_s:>5.2f}s {r.label}" + ) + if r.capability_note: + print(f" cap: {r.capability_note}") + if r.live_note: + print(f" live: {r.live_note}") + print("-" * 92) + print( + f"Total: {passed} ok / {failed} fail / {skipped} skip " + f"(of {len(self.results)} probes)" + ) + print("=" * 92) + return failed + + +# --------------------------------------------------------------------------- +# Capability probes (no network) +# --------------------------------------------------------------------------- + + +def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: + """For chat configs the catalog flag is *expected* True (vision-capable 
+ pool). The probe reports both the resolver value and the strict + safety-net value to surface any drift between them.""" + litellm_params = cfg.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") if isinstance(litellm_params, dict) else None + ) + cap = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + block = is_known_text_only_chat_model( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + note = f"derive={cap} strict_block={block}" + if not cap and not block: + # Resolver said False but strict gate is also False — that means + # OR modalities published [text] explicitly. Surface it. + note += " (OR modality says text-only)" + # We accept a True derive *or* (False derive AND False block) as + # 'capability ok' — either way, the streaming task will flow through. + ok = cap or not block + return ok, note + + +def _build_chat_model_string(cfg: dict) -> str: + if cfg.get("custom_provider"): + return f"{cfg['custom_provider']}/{cfg['model_name']}" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) + return f"{prefix}/{cfg['model_name']}" + + +# --------------------------------------------------------------------------- +# Live probes (network calls) +# --------------------------------------------------------------------------- + + +async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: + """Send a 1x1 PNG + `reply with one word: ok` to the chat config.""" + model_string = _build_chat_model_string(cfg) + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=model_string.split("/", 1)[0], + config_api_base=cfg.get("api_base") or None, + ) + kwargs: dict[str, Any] = { + "model": model_string, + "api_key": cfg.get("api_key"), + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "reply with one word: ok"}, + { + "type": "image_url", + "image_url": {"url": _TINY_PNG_DATA_URL}, + }, + ], + } + ], + "max_tokens": 16, + "timeout": 60, + } + if api_base: + kwargs["api_base"] = api_base + if cfg.get("litellm_params"): + # Strip pricing keys — they're tracking-only and confuse some + # provider validators (e.g. azure/openai reject unknown kwargs + # in strict mode). + merged = { + k: v + for k, v in dict(cfg["litellm_params"]).items() + if k + not in { + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_pixel", + "output_cost_per_pixel", + } + } + kwargs.update(merged) + try: + resp = await litellm.acompletion(**kwargs) + except Exception as exc: + return False, f"{type(exc).__name__}: {exc}" + text = resp.choices[0].message.content if resp.choices else "" + return True, f"got reply ({(text or '').strip()[:40]!r})" + + +# Gemini image models occasionally return zero-length ``data`` for the +# minimal "red dot on white" prompt (provider-side safety / empty-output +# quirk reproducible against ``google/gemini-2.5-flash-image`` even when +# the request itself succeeds). Use a more naturalistic prompt and +# retry once with a different one before giving up. +_IMAGE_GEN_PROMPTS: tuple[str, ...] 
= ( + "A simple icon of a coffee cup, flat illustration", + "A small green leaf on a white background", +) + + +async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: + """Generate one tiny image to verify the deployment is reachable.""" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + + if cfg.get("custom_provider"): + prefix = cfg["custom_provider"] + else: + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) + model_string = f"{prefix}/{cfg['model_name']}" + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=prefix, + config_api_base=cfg.get("api_base") or None, + ) + base_kwargs: dict[str, Any] = { + "model": model_string, + "api_key": cfg.get("api_key"), + "n": 1, + "size": "1024x1024", + "timeout": 120, + } + if api_base: + base_kwargs["api_base"] = api_base + if cfg.get("api_version"): + base_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + base_kwargs.update( + { + k: v + for k, v in dict(cfg["litellm_params"]).items() + if k + not in { + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_pixel", + "output_cost_per_pixel", + } + } + ) + + last_note = "" + for attempt, prompt in enumerate(_IMAGE_GEN_PROMPTS, start=1): + try: + resp = await litellm.aimage_generation(prompt=prompt, **base_kwargs) + except Exception as exc: + last_note = f"{type(exc).__name__}: {exc}" + continue + data_count = len(getattr(resp, "data", None) or []) + if data_count > 0: + return True, ( + f"received {data_count} image(s) on attempt {attempt} " + f"(prompt={prompt!r})" + ) + last_note = ( + f"call ok but received 0 images on attempt {attempt} (prompt={prompt!r})" + ) + return False, last_note + + +# --------------------------------------------------------------------------- +# Probe drivers +# --------------------------------------------------------------------------- + + +def _is_or_dynamic(cfg: dict) -> bool: + return bool(cfg.get(_OPENROUTER_DYNAMIC_MARKER)) + + +async def probe_chat_configs(report: Report, *, live: bool) -> None: + print("\n[chat configs from global_llm_configs (YAML-static)]") + for cfg in config.GLOBAL_LLM_CONFIGS: + # Skip OR dynamic entries here — handled in the OR section so + # the YAML / OR split stays clear in the report. + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="chat-yaml", + config_id=cfg.get("id"), + ) + cap_ok, cap_note = _probe_chat_capability(cfg) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await _live_chat_image_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_vision_configs(report: Report, *, live: bool) -> None: + print("\n[vision configs from global_vision_llm_configs (YAML-static)]") + for cfg in config.GLOBAL_VISION_LLM_CONFIGS: + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="vision", + config_id=cfg.get("id"), + ) + # For vision configs, capability is implied — they're in the + # dedicated vision pool. Run the same resolver to flag any + # surprise disagreement. 
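+        # (A text-only model mistakenly listed in the vision pool would
+        # show cap=fail in the report while the live probe below still runs.)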
+ cap_ok, cap_note = _probe_chat_capability(cfg) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await _live_chat_image_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_image_gen_configs(report: Report, *, live: bool) -> None: + print( + "\n[image generation configs from global_image_generation_configs (YAML-static)]" + ) + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="image-gen", + config_id=cfg.get("id"), + ) + # Image gen configs don't have a "supports_image_input" flag; + # the catalog tracks output, not input. Mark capability as None + # (skip) for the report. + if live: + t0 = time.perf_counter() + ok, note = await _live_image_gen_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: + """Sample one chat (vision-capable), one vision, one image-gen model + from the live OpenRouter catalogue. Doesn't iterate the full pool + (would be hundreds of probes); just validates the integration end- + to-end on a representative model from each surface.""" + print("\n[OpenRouter integration: sampled probes]") + settings = config.OPENROUTER_INTEGRATION_SETTINGS + if not settings: + report.add( + ProbeResult( + label="OpenRouter integration", + surface="openrouter", + config_id="settings", + capability_ok=None, + capability_note="openrouter_integration disabled in YAML — skipping", + live_ok=None, + ) + ) + return + + service = OpenRouterIntegrationService.get_instance() + or_chat = [ + c + for c in config.GLOBAL_LLM_CONFIGS + if c.get("provider") == "OPENROUTER" and c.get("supports_image_input") + ] + or_vision = [ + c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER" + ] + or_image_gen = [ + c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER" + ] + + # Pick one representative per provider family per surface so a single + # broken vendor (e.g. Anthropic key revoked, Google quota exceeded) + # surfaces independently of the others. Each needle matches the + # OpenRouter ``model_name`` prefix; the first match wins. + def _pick_first(pool: list[dict], needle: str) -> dict | None: + for c in pool: + if (c.get("model_name") or "").lower().startswith(needle): + return c + return None + + chat_picks = [ + ("or-chat", _pick_first(or_chat, "openai/gpt-4o")), + ("or-chat", _pick_first(or_chat, "anthropic/claude")), + ("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")), + ] + vision_picks = [ + ("or-vision", _pick_first(or_vision, "openai/gpt-4o")), + ("or-vision", _pick_first(or_vision, "anthropic/claude")), + ("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")), + ] + image_picks = [ + ("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")), + # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*`` + # / ``openai/gpt-5.4-image-2`` (no ``gpt-image`` literal). Match + # the actual prefix. 
+ ("or-image", _pick_first(or_image_gen, "openai/gpt-5-image")), + ] + + print( + f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} " + f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})" + ) + + for surface, picked in chat_picks + vision_picks + image_picks: + if not picked: + report.add( + ProbeResult( + label=f"", + surface=surface, + config_id="-", + capability_ok=None, + capability_note="no candidate found in OR catalog", + ) + ) + continue + runner = ( + _live_image_gen_call if surface == "or-image" else _live_chat_image_call + ) + result = ProbeResult( + label=str(picked.get("model_name")), + surface=surface, + config_id=picked.get("id"), + ) + if surface != "or-image": + cap_ok, cap_note = _probe_chat_capability(picked) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await runner(picked) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +async def main(args: argparse.Namespace) -> int: + print("Loaded global configs:") + print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries") + print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries") + print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries") + print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}") + + # Initialize the OpenRouter integration so the catalog is populated + # (this is what main.py does at startup). It's idempotent. + if config.OPENROUTER_INTEGRATION_SETTINGS: + try: + from app.config import initialize_openrouter_integration + + initialize_openrouter_integration() + except Exception as exc: + print(f" WARNING: OpenRouter integration init failed: {exc}") + + print( + f"\nMode: {'LIVE (will hit providers)' if args.live else 'DRY (capability only)'}" + ) + + report = Report() + if not args.skip_chat: + await probe_chat_configs(report, live=args.live) + if not args.skip_vision: + await probe_vision_configs(report, live=args.live) + if not args.skip_image_gen: + await probe_image_gen_configs(report, live=args.live) + if not args.skip_openrouter: + await probe_openrouter_catalog(report, live=args.live) + + failed = report.render() + return 1 if failed else 0 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--no-live", + dest="live", + action="store_false", + help="Skip live API calls — capability resolver only.", + ) + parser.set_defaults(live=True) + parser.add_argument("--skip-chat", action="store_true") + parser.add_argument("--skip-vision", action="store_true") + parser.add_argument("--skip-image-gen", action="store_true") + parser.add_argument("--skip-openrouter", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + sys.exit(asyncio.run(main(args))) diff --git a/surfsense_backend/tests/integration/chat/__init__.py b/surfsense_backend/tests/integration/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/integration/chat/test_append_message_recovery.py b/surfsense_backend/tests/integration/chat/test_append_message_recovery.py new file mode 100644 index 000000000..a5182a978 --- /dev/null +++ 
b/surfsense_backend/tests/integration/chat/test_append_message_recovery.py @@ -0,0 +1,573 @@ +"""Integration tests for the cross-writer integration between the +streaming chat task and the legacy ``POST /threads/{id}/messages`` +(``append_message``) round-trip. + +Two scenarios anchor the contract introduced by the server-side +persistence rework: + +(a) **Tool-heavy turn streamed to completion.** + + Drives :class:`AssistantContentBuilder` with synthetic SSE events + that mirror what ``_stream_agent_events`` emits for a turn that + interleaves text, reasoning, a tool call (start/delta/available/ + output), and a final text block. Then runs + :func:`finalize_assistant_turn` and asserts: + + * ``new_chat_messages.content`` JSONB matches the + ``ContentPart[]`` shape the FE history loader expects, with full + ``args``/``argsText``/``result``/``langchainToolCallId`` for the + tool call. + * Exactly one ``token_usage`` row exists keyed on the assistant + ``message_id``. + +(b) **Stale FE ``appendMessage`` after server finalize.** + + Verifies the recovery branch of the ``append_message`` route now + returns the SERVER's authoritative ``ContentPart[]`` (not the FE's + stale payload) when the partial unique index from migration 141 + blocks the FE's INSERT, and that the ``ON CONFLICT DO NOTHING`` + clause from migration 142 stops the route from writing a duplicate + ``token_usage`` row. +""" + +from __future__ import annotations + +import json +from contextlib import asynccontextmanager + +import pytest +import pytest_asyncio +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + ChatVisibility, + NewChatMessage, + NewChatMessageRole, + NewChatThread, + SearchSpace, + TokenUsage, + User, +) +from app.routes import new_chat_routes +from app.services.token_tracking_service import TurnTokenAccumulator +from app.tasks.chat import persistence as persistence_module +from app.tasks.chat.content_builder import AssistantContentBuilder +from app.tasks.chat.persistence import ( + finalize_assistant_turn, + persist_assistant_shell, +) + +pytestmark = pytest.mark.integration + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def db_thread( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +) -> NewChatThread: + thread = NewChatThread( + title="Test Chat", + search_space_id=db_search_space.id, + created_by_id=db_user.id, + visibility=ChatVisibility.PRIVATE, + ) + db_session.add(thread) + await db_session.flush() + return thread + + +@pytest.fixture +def patched_shielded_session(monkeypatch, db_session: AsyncSession): + """Route persistence helpers to the test's savepoint-bound session. + + Mirrors the helper from ``test_persistence.py`` so the helpers' + internal ``ws.commit()`` / ``ws.rollback()`` resolve to SAVEPOINT + operations on the test transaction instead of touching real + autocommit boundaries. + """ + + @asynccontextmanager + async def _fake_shielded_session(): + yield db_session + + monkeypatch.setattr( + persistence_module, + "shielded_async_session", + _fake_shielded_session, + ) + return db_session + + +@pytest.fixture +def bypass_permission_checks(monkeypatch): + """Replace RBAC + thread access checks with no-ops. 
+ + The append_message route under test calls ``check_permission`` and + ``check_thread_access``; those rely on a SearchSpaceMembership row + that the existing integration fixtures don't create. The contract + we want to verify here is the ``IntegrityError`` -> recovery branch, + not the RBAC plumbing — so stub them. + """ + + async def _allow(*_args, **_kwargs): + return True + + monkeypatch.setattr(new_chat_routes, "check_permission", _allow) + monkeypatch.setattr(new_chat_routes, "check_thread_access", _allow) + return None + + +class _FakeRequest: + """Minimal Request stand-in used by ``append_message``. + + The route only calls ``await request.json()`` — keep the surface + area tight so this doesn't accidentally hide future signature + changes that we *would* want to break the test. + """ + + def __init__(self, body: dict): + self._body = body + + async def json(self) -> dict: + return self._body + + +def _build_tool_heavy_content() -> list[dict]: + """Drive ``AssistantContentBuilder`` through a tool-heavy turn. + + Produces the same ``ContentPart[]`` shape the streaming layer would + persist if ``_stream_agent_events`` ran a turn with: opening + reasoning -> text -> tool call (input start/delta/available/output) + -> closing text. Centralised here so the (a) and (b) scenarios use + the same authoritative payload. + """ + builder = AssistantContentBuilder() + + builder.on_reasoning_start("r1") + builder.on_reasoning_delta("r1", "Let me look up ") + builder.on_reasoning_delta("r1", "the file listing.") + builder.on_reasoning_end("r1") + + builder.on_text_start("t1") + builder.on_text_delta("t1", "Sure, listing files in ") + builder.on_text_delta("t1", "/.") + builder.on_text_end("t1") + + builder.on_tool_input_start( + "tool_call_ui_1", + tool_name="ls", + langchain_tool_call_id="lc_call_xyz", + ) + builder.on_tool_input_delta("tool_call_ui_1", '{"path"') + builder.on_tool_input_delta("tool_call_ui_1", ': "/"}') + builder.on_tool_input_available( + "tool_call_ui_1", + tool_name="ls", + args={"path": "/"}, + langchain_tool_call_id="lc_call_xyz", + ) + builder.on_tool_output_available( + "tool_call_ui_1", + output={"files": ["a.txt", "b.txt"]}, + langchain_tool_call_id="lc_call_xyz", + ) + + builder.on_text_start("t2") + builder.on_text_delta("t2", "Found 2 files: a.txt and b.txt.") + builder.on_text_end("t2") + + return builder.snapshot() + + +def _accumulator_with_one_call() -> TurnTokenAccumulator: + acc = TurnTokenAccumulator() + acc.add( + model="gpt-4o-mini", + prompt_tokens=200, + completion_tokens=80, + total_tokens=280, + cost_micros=22222, + ) + return acc + + +# --------------------------------------------------------------------------- +# (a) Tool-heavy stream finalize +# --------------------------------------------------------------------------- + + +class TestToolHeavyTurnFinalize: + async def test_full_tool_call_persisted_and_one_token_usage_row( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + """End-to-end seam: builder snapshot -> finalize -> DB row. + + Matches the production flow's *content* invariant: whatever + ``AssistantContentBuilder.snapshot()`` produces is what the + streaming layer hands to ``finalize_assistant_turn``, so this + test catches any drift between the JSONB shape the builder + emits and the one the FE history loader expects. 
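+
+        For orientation, the tool-call part this turn persists looks
+        like this (shape pinned by the assertions below; values are
+        this test's fixtures, not production data):
+
+            {
+                "type": "tool-call",
+                "toolCallId": "tool_call_ui_1",
+                "toolName": "ls",
+                "args": {"path": "/"},
+                "argsText": json.dumps({"path": "/"}, indent=2),
+                "result": {"files": ["a.txt", "b.txt"]},
+                "langchainToolCallId": "lc_call_xyz",
+            }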
+ """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:tool_heavy" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + snapshot = _build_tool_heavy_content() + # Sanity-check the snapshot before we hand it to the DB so a + # builder regression surfaces here, not deep inside an opaque + # JSONB diff. + assert any(p.get("type") == "reasoning" for p in snapshot) + text_parts = [p for p in snapshot if p.get("type") == "text"] + assert len(text_parts) == 2 + tool_parts = [p for p in snapshot if p.get("type") == "tool-call"] + assert len(tool_parts) == 1 + tool_part = tool_parts[0] + assert tool_part["toolCallId"] == "tool_call_ui_1" + assert tool_part["toolName"] == "ls" + assert tool_part["args"] == {"path": "/"} + # ``argsText`` ends up as the pretty-printed final args (the + # ``tool-input-available`` event replaces the streamed deltas + # with ``json.dumps(args, indent=2)`` to match the FE's + # post-stream rendering). + assert tool_part["argsText"] == '{\n "path": "/"\n}' + assert tool_part["result"] == {"files": ["a.txt", "b.txt"]} + # ``langchainToolCallId`` is the agent-side correlation id used + # by the regenerate path; a missing one breaks + # edit-from-tool-call later. + assert tool_part["langchainToolCallId"] == "lc_call_xyz" + + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=snapshot, + accumulator=_accumulator_with_one_call(), + ) + + # ``content`` must round-trip byte-for-byte through the JSONB + # column. SQLAlchemy doesn't auto-refresh the row that survived + # the savepoint commit, so refresh explicitly. + row = await db_session.get(NewChatMessage, msg_id) + await db_session.refresh(row) + + # The history loader reads ``content`` straight into the FE's + # parts array, so a strict equality comparison is the right + # invariant here. + assert row.content == snapshot + # Tool-call parts must JSON-serialise cleanly — nothing in + # ``args`` / ``argsText`` / ``result`` should accidentally be a + # non-JSON-safe value (datetime, set, custom class). + assert json.dumps(row.content) + + usage_count = ( + await db_session.execute( + select(func.count()) + .select_from(TokenUsage) + .where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage_count == 1 + + usage = ( + await db_session.execute( + select(TokenUsage).where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage.usage_type == "chat" + assert usage.prompt_tokens == 200 + assert usage.completion_tokens == 80 + assert usage.total_tokens == 280 + assert usage.cost_micros == 22222 + assert usage.thread_id == thread_id + assert usage.search_space_id == search_space_id + + +# --------------------------------------------------------------------------- +# (b) FE appendMessage after server finalize +# --------------------------------------------------------------------------- + + +class TestAppendMessageRecoveryAfterFinalize: + async def test_returns_server_content_and_does_not_duplicate_token_usage( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + bypass_permission_checks, + ): + """FE's stale ``appendMessage`` after server finalize. + + The frontend used to be the authoritative writer for assistant + ``content``. Now the server is. 
When the legacy FE round-trip + fires *after* the server has already finalized: + + * the route's INSERT trips the (thread_id, turn_id, role) + partial unique index from migration 141, + * the recovery branch fetches the existing row and returns + *its* ``content`` — discarding the FE payload — so the + history loader reads the rich server payload (full tool + args, argsText, langchainToolCallId, etc.) on next page + reload, + * the route's optional ``token_usage`` insert is keyed on the + partial unique index from migration 142 so it silently + no-ops if the server already wrote one. + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:fe_late_append" + + # Step 1: server stream completes. Server-built rich content is + # finalized. + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + server_content = _build_tool_heavy_content() + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=server_content, + accumulator=_accumulator_with_one_call(), + ) + + # Step 2: simulate the legacy FE ``appendMessage`` round-trip + # arriving with stale, lossy content (missing tool args, etc.) + # plus a ``token_usage`` body. + fe_stale_content = [ + {"type": "text", "text": "Found 2 files: a.txt and b.txt."}, + ] + fe_request_body = { + "role": "assistant", + "content": fe_stale_content, + "turn_id": turn_id, + "token_usage": { + "prompt_tokens": 999, + "completion_tokens": 999, + "total_tokens": 1998, + "cost_micros": 88888, + "usage": {"any": "thing"}, + "call_details": {"calls": []}, + }, + } + request = _FakeRequest(fe_request_body) + + # ``db_user`` is bound to ``db_session``. The route's + # IntegrityError branch calls ``session.rollback()``, which + # expires every ORM row attached to the session including this + # user — historically causing ``user.id`` to lazy-load + # out-of-greenlet and crash the request with ``MissingGreenlet`` + # (observed in production logs at /api/v1/threads/531/messages + # 2026-05-04). The route now captures ``user.id`` to a primitive + # UUID at the top of the handler, so the rollback can't reach + # it. Pass the *attached* user here on purpose — that's the + # production scenario, and this test is the regression guard + # against that bug returning. + response = await new_chat_routes.append_message( + thread_id=thread_id, + request=request, + session=db_session, + user=db_user, + ) + + # Response must echo the SERVER's rich payload, not the FE's + # stale snapshot. This is the user-visible part of the + # contract: history reload + ThreadHistoryAdapter.append both + # read from the same authoritative source. + assert response.id == msg_id + assert response.role == NewChatMessageRole.ASSISTANT + assert response.turn_id == turn_id + assert response.content == server_content + assert response.content != fe_stale_content + + # The on-disk row must agree with the response. + row = await db_session.get(NewChatMessage, msg_id) + await db_session.refresh(row) + assert row.content == server_content + + # ``token_usage``: exactly one row, with the *server's* values + # (the FE's much larger token counts must not have overwritten + # them). 
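+        # (The route-side guard is the same INSERT ... ON CONFLICT
+        #  (message_id) DO NOTHING shape exercised directly in
+        #  test_persistence.py::TestFinalizeAssistantTurn, keyed on the
+        #  migration-142 partial unique index, so the losing writer's
+        #  insert silently no-ops.)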
+ usage_count = ( + await db_session.execute( + select(func.count()) + .select_from(TokenUsage) + .where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage_count == 1 + + usage = ( + await db_session.execute( + select(TokenUsage).where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage.cost_micros == 22222 # Server's value, not 88888 + assert usage.total_tokens == 280 # Server's value, not 1998 + + async def test_legacy_fe_first_appendmessage_then_server_no_dupe( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + bypass_permission_checks, + ): + """Inverse race: legacy FE writes first, server finalize second. + + Some clients still post ``appendMessage`` before the streaming + ``finally`` runs. The contract is symmetric: whichever writer + loses the (thread_id, turn_id, role) race silently lets the + winner keep its content. In particular the *server's* + finalize must NOT raise — it must look up the existing row and + UPDATE its content with the server-built payload (which is + always richer/more authoritative than whatever the FE + snapshot held). + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:fe_first" + + # Step 1: legacy FE appendMessage lands first. No prior shell + # row exists; the route does the INSERT itself. + fe_request_body = { + "role": "assistant", + "content": [{"type": "text", "text": "early FE write"}], + "turn_id": turn_id, + } + fe_response = await new_chat_routes.append_message( + thread_id=thread_id, + request=_FakeRequest(fe_request_body), + session=db_session, + user=db_user, + ) + assert fe_response.role == NewChatMessageRole.ASSISTANT + + # Step 2: server stream's persist_assistant_shell now races + # behind. It must adopt the existing row id, not raise. + adopted_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert adopted_id == fe_response.id + + # Step 3: server finalize then overwrites the FE's stub with + # the rich content (which is the correct, more authoritative + # payload). + server_content = _build_tool_heavy_content() + await finalize_assistant_turn( + message_id=adopted_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=server_content, + accumulator=_accumulator_with_one_call(), + ) + + # Final state: one row, server content, one token_usage row. + msg_count = ( + await db_session.execute( + select(func.count()) + .select_from(NewChatMessage) + .where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.ASSISTANT, + ) + ) + ).scalar_one() + assert msg_count == 1 + + row = await db_session.get(NewChatMessage, adopted_id) + await db_session.refresh(row) + assert row.content == server_content + + usage_count = ( + await db_session.execute( + select(func.count()) + .select_from(TokenUsage) + .where(TokenUsage.message_id == adopted_id) + ) + ).scalar_one() + assert usage_count == 1 + + async def test_appendmessage_without_turn_id_legacy_400( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + bypass_permission_checks, + ): + """Defensive: a bare appendMessage with no turn_id and no + existing row is just a normal INSERT — must succeed. 
But if an IntegrityError unrelated to the turn_id
+        unique index (a foreign-key violation, say) fires for a row
+        with the same role, the route should fall through to the
+        legacy 400 path (we don't ship that bug today, but pin the
+        behaviour so a future schema change can't silently regress it).
+        """
+        thread_id = db_thread.id
+
+        # Bare appendMessage with no turn_id — should just succeed
+        # without invoking the recovery branch.
+        ok_response = await new_chat_routes.append_message(
+            thread_id=thread_id,
+            request=_FakeRequest(
+                {
+                    "role": "user",
+                    "content": [{"type": "text", "text": "hi"}],
+                }
+            ),
+            session=db_session,
+            user=db_user,
+        )
+        assert ok_response.role == NewChatMessageRole.USER
+        assert ok_response.turn_id is None
+
+        # Sanity: the route did NOT silently swallow the missing
+        # turn_id by routing through the unique-index recovery branch
+        # — it took the happy path.
+        msg_count = (
+            await db_session.execute(
+                select(func.count())
+                .select_from(NewChatMessage)
+                .where(
+                    NewChatMessage.thread_id == thread_id,
+                    NewChatMessage.role == NewChatMessageRole.USER,
+                )
+            )
+        ).scalar_one()
+        assert msg_count == 1
diff --git a/surfsense_backend/tests/integration/chat/test_message_id_sse.py b/surfsense_backend/tests/integration/chat/test_message_id_sse.py
new file mode 100644
index 000000000..8fc935eaa
--- /dev/null
+++ b/surfsense_backend/tests/integration/chat/test_message_id_sse.py
@@ -0,0 +1,332 @@
+"""Integration tests for the SSE-based message ID handshake.
+
+The streaming generators (``stream_new_chat`` / ``stream_resume_chat``)
+emit two new events after their respective persistence helpers resolve
+the canonical ``new_chat_messages.id``:
+
+* ``data-user-message-id`` — emitted only by ``stream_new_chat``,
+  AFTER ``persist_user_turn`` and BEFORE any LLM streaming.
+* ``data-assistant-message-id`` — emitted by both
+  ``stream_new_chat`` and ``stream_resume_chat``, AFTER
+  ``persist_assistant_shell`` and BEFORE any LLM streaming.
+
+The frontend renames its optimistic ``msg-user-XXX`` /
+``msg-assistant-XXX`` placeholder ids to ``msg-{db_id}`` upon receiving
+these events. This test suite anchors three contracts:
+
+1. ``format_data`` produces SSE bytes in the precise shape
+   ``data: {"type":"data-<name>","data":{...}}\\n\\n`` that the FE's
+   ``readSSEStream`` consumer parses (matches ``surfsense_web/lib/chat/streaming-state.ts``).
+2. The ``message_id`` carried in the SSE payload exactly equals the
+   primary key the persistence helper inserted into
+   ``new_chat_messages`` — so the FE rename produces ``msg-{real_pk}``,
+   which in turn unlocks DB-id-gated UI (comments, edit-from-message).
+3. The same ``message_id`` is used for the ``token_usage.message_id``
+   foreign key, so ``finalize_assistant_turn``'s row binds correctly.
+
+Direct end-to-end testing of ``stream_new_chat`` requires a fully
+mocked agent + LLM stack (out-of-scope here); those flows are covered
+by the harness-driven integration tests under
+``tests/integration/agents/new_chat/`` plus the assertion in
+``test_persistence.py`` that the helpers themselves return ``int``
+ids. The contracts above close the remaining gap between the persist
+helpers and the bytes that ship to the FE.
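+
+For orientation, a handshake frame on the wire (shape mirrored from the
+byte-shape tests below; values illustrative) is a single SSE data line:
+
+    data: {"type":"data-assistant-message-id","data":{"message_id":1844,"turn_id":"533:1762900000000"}}\n\n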
+""" + +from __future__ import annotations + +import json +from contextlib import asynccontextmanager + +import pytest +import pytest_asyncio +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + ChatVisibility, + NewChatMessage, + NewChatMessageRole, + NewChatThread, + SearchSpace, + User, +) +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat import persistence as persistence_module +from app.tasks.chat.persistence import ( + persist_assistant_shell, + persist_user_turn, +) + +pytestmark = pytest.mark.integration + + +# --------------------------------------------------------------------------- +# Fixtures (mirror test_persistence.py) +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def db_thread( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +) -> NewChatThread: + thread = NewChatThread( + title="Test Chat", + search_space_id=db_search_space.id, + created_by_id=db_user.id, + visibility=ChatVisibility.PRIVATE, + ) + db_session.add(thread) + await db_session.flush() + return thread + + +@pytest.fixture +def patched_shielded_session(monkeypatch, db_session: AsyncSession): + """Route persistence helpers to the test's savepoint-bound session.""" + + @asynccontextmanager + async def _fake_shielded_session(): + yield db_session + + monkeypatch.setattr( + persistence_module, + "shielded_async_session", + _fake_shielded_session, + ) + return db_session + + +# --------------------------------------------------------------------------- +# (1) SSE byte-shape contract +# --------------------------------------------------------------------------- + + +def _parse_sse_data_line(blob: str) -> dict: + """Unwrap a single ``data: \\n\\n`` SSE frame. + + Raises if there's more than one frame or the prefix is wrong — keeps + the parser strict so a regression in ``format_data`` produces a + test failure here, not in a downstream consumer. + """ + assert blob.endswith("\n\n"), f"missing terminator: {blob!r}" + line = blob.removesuffix("\n\n") + assert line.startswith("data: "), f"missing data prefix: {line!r}" + return json.loads(line.removeprefix("data: ")) + + +class TestSSEByteShape: + def test_data_user_message_id_byte_shape(self): + """``format_data("user-message-id", {...})`` must produce the + exact wire format the FE's + ``readSSEStream`` -> ``data-user-message-id`` case parses. 
+ """ + svc = VercelStreamingService() + blob = svc.format_data( + "user-message-id", + {"message_id": 1843, "turn_id": "533:1762900000000"}, + ) + envelope = _parse_sse_data_line(blob) + assert envelope == { + "type": "data-user-message-id", + "data": {"message_id": 1843, "turn_id": "533:1762900000000"}, + } + + def test_data_assistant_message_id_byte_shape(self): + svc = VercelStreamingService() + blob = svc.format_data( + "assistant-message-id", + {"message_id": 1844, "turn_id": "533:1762900000000"}, + ) + envelope = _parse_sse_data_line(blob) + assert envelope == { + "type": "data-assistant-message-id", + "data": {"message_id": 1844, "turn_id": "533:1762900000000"}, + } + + +# --------------------------------------------------------------------------- +# (2) Helper-id <-> DB-pk coherence +# --------------------------------------------------------------------------- + + +class TestHandshakeIdMatchesDB: + """The SSE handshake's correctness hinges on the integer in + ``data-{user,assistant}-message-id`` being the EXACT primary key + the persistence helper inserted. If they ever diverge, the FE + rename produces ``msg-{wrong_id}``, comments break (regex match + fails), and downstream features (edit, regenerate) target the + wrong row. Anchor it here. + """ + + async def test_user_message_id_matches_new_chat_messages_pk( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:9000" + + # The streaming generator passes this same value into + # ``streaming_service.format_data("user-message-id", {...})``. + msg_id_from_helper = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hello", + ) + assert isinstance(msg_id_from_helper, int) + + # Look up the row the helper inserted via + # ``(thread_id, turn_id, role)`` — the same composite the FE + # uses to identify a turn — and confirm the PK matches. + row = ( + await db_session.execute( + select(NewChatMessage).where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.USER, + ) + ) + ).scalar_one() + assert row.id == msg_id_from_helper + + # The byte-stream the FE actually receives — confirms the + # round-trip from the helper return value to the SSE payload. 
+ svc = VercelStreamingService() + envelope = _parse_sse_data_line( + svc.format_data( + "user-message-id", + {"message_id": msg_id_from_helper, "turn_id": turn_id}, + ) + ) + assert envelope["data"]["message_id"] == row.id + + async def test_assistant_message_id_matches_new_chat_messages_pk( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:9100" + + msg_id_from_helper = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert isinstance(msg_id_from_helper, int) + + row = ( + await db_session.execute( + select(NewChatMessage).where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.ASSISTANT, + ) + ) + ).scalar_one() + assert row.id == msg_id_from_helper + + svc = VercelStreamingService() + envelope = _parse_sse_data_line( + svc.format_data( + "assistant-message-id", + {"message_id": msg_id_from_helper, "turn_id": turn_id}, + ) + ) + assert envelope["data"]["message_id"] == row.id + + async def test_handshake_ids_for_full_turn_are_distinct_and_paired( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + """Sanity: a full new-chat turn's two SSE events carry two + DIFFERENT ids (user row PK ≠ assistant row PK), both anchored + to the SAME ``turn_id`` in the DB. This pairing is what + ``finalize_assistant_turn`` and the regenerate / edit flows + rely on. + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:9200" + + user_msg_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hi", + ) + assistant_msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert user_msg_id is not None and assistant_msg_id is not None + assert user_msg_id != assistant_msg_id + + rows = ( + ( + await db_session.execute( + select(NewChatMessage) + .where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + ) + .order_by(NewChatMessage.id) + ) + ) + .scalars() + .all() + ) + assert len(rows) == 2 + ids_by_role = {r.role: r.id for r in rows} + assert ids_by_role[NewChatMessageRole.USER] == user_msg_id + assert ids_by_role[NewChatMessageRole.ASSISTANT] == assistant_msg_id + + +# --------------------------------------------------------------------------- +# (3) Parse helpers used by the FE — sanity-check our payload shape +# --------------------------------------------------------------------------- + + +class TestPayloadShapeMatchesFEReader: + """The FE's ``readStreamedMessageId`` (in + ``surfsense_web/lib/chat/stream-side-effects.ts``) requires: + + * ``message_id`` is a ``number`` (rejects null / string / NaN). + * ``turn_id`` is an optional non-empty string (else it's coerced + to ``null``). + + These tests exercise the BE side of that contract by inspecting + ``format_data`` output shapes that the FE consumes verbatim. 
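+
+    For instance (FE semantics paraphrased from the bullets above, not
+    quoted from the FE source):
+
+        {"message_id": 42,   "turn_id": "t"}   -> accepted; FE renames to msg-42
+        {"message_id": "42", "turn_id": "t"}   -> rejected (string, not number)
+        {"message_id": 42,   "turn_id": ""}    -> accepted; turn_id coerced to null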
+ """ + + def test_message_id_is_serialised_as_a_json_number(self): + svc = VercelStreamingService() + envelope = _parse_sse_data_line( + svc.format_data("user-message-id", {"message_id": 42, "turn_id": "t"}) + ) + assert isinstance(envelope["data"]["message_id"], int) + assert envelope["data"]["message_id"] == 42 + + def test_turn_id_round_trips_as_string(self): + svc = VercelStreamingService() + # The actual format used in production: f"{chat_id}:{int(time.time()*1000)}" + production_turn_id = "533:1762900000000" + envelope = _parse_sse_data_line( + svc.format_data( + "assistant-message-id", + {"message_id": 1, "turn_id": production_turn_id}, + ) + ) + assert envelope["data"]["turn_id"] == production_turn_id diff --git a/surfsense_backend/tests/integration/chat/test_persistence.py b/surfsense_backend/tests/integration/chat/test_persistence.py new file mode 100644 index 000000000..66a04772e --- /dev/null +++ b/surfsense_backend/tests/integration/chat/test_persistence.py @@ -0,0 +1,747 @@ +"""Integration tests for ``app.tasks.chat.persistence``. + +Verifies the DB-side guarantees the streaming chat task relies on: + +* ``persist_assistant_shell`` is idempotent against the + ``(thread_id, turn_id, ASSISTANT)`` partial unique index from + migration 141. Two calls with the same ``turn_id`` return the SAME + ``message_id`` and never create a duplicate ``new_chat_messages`` row. +* ``finalize_assistant_turn`` writes a status-marker payload when given + empty content, never raises, and is safe to call twice on the same + ``message_id`` — the partial unique index from migration 142 + (``uq_token_usage_message_id``) prevents the second insert from + producing a duplicate ``token_usage`` row. +* The same ``ON CONFLICT DO NOTHING`` invariant covers the cross-writer + race where ``finalize_assistant_turn`` and the ``append_message`` + recovery branch both target the same ``message_id``. + +All tests run inside the conftest's outer-transaction-with-savepoint +fixture so commits inside the helpers (which open their own +``shielded_async_session``) are released as savepoints and rolled back +at test end. We monkey-patch ``shielded_async_session`` to yield the +same pooled test session so the integration transaction stays +in-scope. 
+""" + +from __future__ import annotations + +from contextlib import asynccontextmanager + +import pytest +import pytest_asyncio +from sqlalchemy import func, select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + ChatVisibility, + NewChatMessage, + NewChatMessageRole, + NewChatThread, + SearchSpace, + TokenUsage, + User, +) +from app.services.token_tracking_service import TurnTokenAccumulator +from app.tasks.chat import persistence as persistence_module +from app.tasks.chat.persistence import ( + finalize_assistant_turn, + persist_assistant_shell, + persist_user_turn, +) + +pytestmark = pytest.mark.integration + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def db_thread( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +) -> NewChatThread: + thread = NewChatThread( + title="Test Chat", + search_space_id=db_search_space.id, + created_by_id=db_user.id, + visibility=ChatVisibility.PRIVATE, + ) + db_session.add(thread) + await db_session.flush() + return thread + + +@pytest.fixture +def patched_shielded_session(monkeypatch, db_session: AsyncSession): + """Route persistence helpers to the test's savepoint-bound session. + + The persistence helpers use ``async with shielded_async_session() as + ws`` and call ``ws.commit()`` internally. Inside the conftest's + ``join_transaction_mode="create_savepoint"`` setup, those commits + release a SAVEPOINT instead of committing the outer transaction — + so the test session can see helper-staged rows immediately and the + outer rollback at end of test wipes them. + """ + + @asynccontextmanager + async def _fake_shielded_session(): + yield db_session + # Do NOT close — the outer fixture owns the session lifecycle. + + monkeypatch.setattr( + persistence_module, + "shielded_async_session", + _fake_shielded_session, + ) + return db_session + + +def _accumulator_with_one_call() -> TurnTokenAccumulator: + acc = TurnTokenAccumulator() + acc.add( + model="gpt-4o-mini", + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + cost_micros=12345, + ) + return acc + + +async def _count_assistant_rows( + session: AsyncSession, thread_id: int, turn_id: str +) -> int: + result = await session.execute( + select(func.count()) + .select_from(NewChatMessage) + .where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.ASSISTANT, + ) + ) + return int(result.scalar_one()) + + +async def _count_token_usage_rows(session: AsyncSession, message_id: int) -> int: + result = await session.execute( + select(func.count()) + .select_from(TokenUsage) + .where(TokenUsage.message_id == message_id) + ) + return int(result.scalar_one()) + + +# --------------------------------------------------------------------------- +# persist_assistant_shell +# --------------------------------------------------------------------------- + + +class TestPersistAssistantShell: + async def test_first_call_inserts_empty_shell_and_returns_id( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + # Capture primitive ids before any persistence helper runs: + # the helpers commit/rollback the shared test session, which + # can detach ORM rows mid-test. 
+ thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:1000" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None and isinstance(msg_id, int) + + row = await db_session.get(NewChatMessage, msg_id) + assert row is not None + assert row.thread_id == thread_id + assert row.role == NewChatMessageRole.ASSISTANT + assert row.turn_id == turn_id + # Empty shell payload — finalize_assistant_turn overwrites later. + assert row.content == [{"type": "text", "text": ""}] + + async def test_second_call_with_same_turn_id_returns_same_id( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + # Capture primitive ids before any persistence helper runs: + # the helpers commit/rollback the shared test session, which + # can detach ORM rows mid-test. + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:2000" + + first_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + second_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + + assert first_id is not None + assert first_id == second_id + # Exactly one row in the DB for this turn. + assert await _count_assistant_rows(db_session, thread_id, turn_id) == 1 + + async def test_missing_turn_id_returns_none( + self, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id="", + ) + assert msg_id is None + + async def test_after_persist_user_turn_resolves_assistant_id( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:3000" + + # The streaming layer always calls persist_user_turn first, so + # smoke-test the canonical sequence. + user_msg_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hello", + ) + assert isinstance(user_msg_id, int) + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + # User row + assistant shell row = 2 rows for this turn. + result = await db_session.execute( + select(func.count()) + .select_from(NewChatMessage) + .where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + ) + ) + assert result.scalar_one() == 2 + + async def test_double_call_with_same_turn_id_uses_on_conflict( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + """Verifies the ON CONFLICT DO NOTHING path on the assistant + shell does not raise ``IntegrityError`` even when the second + writer races the first within a tight loop. ``test_second_call_with_same_turn_id_returns_same_id`` + already covers the same-id semantics; this test additionally + asserts neither call raises so the debugger's + ``raise-on-IntegrityError`` setting won't pause the streaming + path under contention. + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:3500" + + # Both calls go through ``pg_insert(...).on_conflict_do_nothing``; + # the second one returns RETURNING=∅ and falls into the SELECT + # branch. Neither path raises. 
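+        # (Helper internals, sketched under assumption; the real code
+        #  lives in app.tasks.chat.persistence:
+        #      stmt = (
+        #          pg_insert(NewChatMessage)
+        #          .values(thread_id=..., turn_id=..., role=ASSISTANT, ...)
+        #          .on_conflict_do_nothing()
+        #          .returning(NewChatMessage.id)
+        #      )
+        #      msg_id = (await ws.execute(stmt)).scalar_one_or_none()
+        #      if msg_id is None:  # lost the race, adopt the winner
+        #          msg_id = SELECT the existing (thread, turn, role) row's id)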
+ first_id = await persist_assistant_shell( + chat_id=thread_id, user_id=user_id_str, turn_id=turn_id + ) + second_id = await persist_assistant_shell( + chat_id=thread_id, user_id=user_id_str, turn_id=turn_id + ) + assert first_id is not None + assert first_id == second_id + + +# --------------------------------------------------------------------------- +# persist_user_turn +# --------------------------------------------------------------------------- + + +class TestPersistUserTurn: + async def test_returns_message_id_on_first_insert( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:8000" + + msg_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hello", + ) + assert isinstance(msg_id, int) and msg_id > 0 + + row = await db_session.get(NewChatMessage, msg_id) + assert row is not None + assert row.thread_id == thread_id + assert row.role == NewChatMessageRole.USER + assert row.turn_id == turn_id + assert row.content == [{"type": "text", "text": "hello"}] + + async def test_returns_existing_id_on_conflict( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:8100" + + first_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hello", + ) + # Second call simulates a legacy FE ``appendMessage`` racing the + # SSE stream: ON CONFLICT DO NOTHING short-circuits at the DB + # level, the helper recovers the existing id via SELECT, and + # crucially does NOT raise ``IntegrityError`` (the debugger + # would otherwise pause on it). + second_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="ignored on conflict", + ) + assert first_id is not None + assert first_id == second_id + + # Exactly one user row for this turn. + count = await db_session.execute( + select(func.count()) + .select_from(NewChatMessage) + .where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.USER, + ) + ) + assert count.scalar_one() == 1 + + async def test_embeds_mentioned_documents_part( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + """The full ``{id, title, document_type}`` triple forwarded by + the FE must round-trip into a single ``mentioned-documents`` + ContentPart on the persisted user message — the history loader + renders the chips on reload from this part directly. + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:8200" + + mentioned = [ + {"id": 11, "title": "Alpha", "document_type": "GENERAL"}, + {"id": 22, "title": "Beta", "document_type": "GENERAL"}, + ] + msg_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hello", + mentioned_documents=mentioned, + ) + assert isinstance(msg_id, int) + + row = await db_session.get(NewChatMessage, msg_id) + assert row is not None + # Content is a 2-part list: text + mentioned-documents. 
+ assert isinstance(row.content, list) + assert row.content[0] == {"type": "text", "text": "hello"} + assert row.content[1] == { + "type": "mentioned-documents", + "documents": [ + {"id": 11, "title": "Alpha", "document_type": "GENERAL"}, + {"id": 22, "title": "Beta", "document_type": "GENERAL"}, + ], + } + + async def test_skips_mentioned_documents_when_empty_or_invalid( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + """Empty list and entries missing required fields are dropped; + a ``mentioned-documents`` part is only emitted when at least + one normalised entry survived. + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id_empty = f"{thread_id}:8300" + turn_id_invalid = f"{thread_id}:8301" + + msg_id_empty = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id_empty, + user_query="hi", + mentioned_documents=[], + ) + assert isinstance(msg_id_empty, int) + row_empty = await db_session.get(NewChatMessage, msg_id_empty) + assert row_empty is not None + assert row_empty.content == [{"type": "text", "text": "hi"}] + + # Each entry missing one required field — all skipped. + msg_id_invalid = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id_invalid, + user_query="hi", + mentioned_documents=[ + {"title": "no id", "document_type": "GENERAL"}, # missing id + {"id": 99, "document_type": "GENERAL"}, # missing title + {"id": 100, "title": "no type"}, # missing document_type + ], + ) + assert isinstance(msg_id_invalid, int) + row_invalid = await db_session.get(NewChatMessage, msg_id_invalid) + assert row_invalid is not None + assert row_invalid.content == [{"type": "text", "text": "hi"}] + + async def test_missing_turn_id_returns_none( + self, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + + msg_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id="", + user_query="hello", + ) + assert msg_id is None + + +# --------------------------------------------------------------------------- +# finalize_assistant_turn +# --------------------------------------------------------------------------- + + +class TestFinalizeAssistantTurn: + async def test_writes_content_and_token_usage( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_uuid = db_user.id + user_id_str = str(user_id_uuid) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:4000" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + rich_content = [ + {"type": "text", "text": "Hello world"}, + { + "type": "tool-call", + "toolCallId": "call_x", + "toolName": "ls", + "args": {"path": "/"}, + "argsText": '{\n "path": "/"\n}', + "result": {"files": []}, + "langchainToolCallId": "lc_x", + }, + ] + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=rich_content, + accumulator=_accumulator_with_one_call(), + ) + + row = await db_session.get(NewChatMessage, msg_id) + await db_session.refresh(row) + assert row.content == rich_content + + # Exactly one token_usage row keyed on this message_id. 
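+        # (Migration 142's guard, approximately; exact DDL assumed:
+        #      CREATE UNIQUE INDEX uq_token_usage_message_id
+        #          ON token_usage (message_id)
+        #          WHERE message_id IS NOT NULL;)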
+ usage_rows = ( + ( + await db_session.execute( + select(TokenUsage).where(TokenUsage.message_id == msg_id) + ) + ) + .scalars() + .all() + ) + assert len(usage_rows) == 1 + usage = usage_rows[0] + assert usage.usage_type == "chat" + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 150 + assert usage.cost_micros == 12345 + assert usage.thread_id == thread_id + assert usage.search_space_id == search_space_id + + async def test_empty_content_writes_status_marker( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:5000" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + # Pure tool-call turn that aborted before any output, or + # interrupt before any event arrived — empty list. + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=[], + accumulator=None, + ) + + row = await db_session.get(NewChatMessage, msg_id) + await db_session.refresh(row) + assert row.content == [{"type": "status", "text": "(no text response)"}] + + async def test_double_call_safe_via_on_conflict( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:6000" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + first_acc = _accumulator_with_one_call() + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=[{"type": "text", "text": "first finalize"}], + accumulator=first_acc, + ) + + # Simulate a follow-up finalize (e.g., resume retry within the + # shielded finally block firing twice). Different content, but + # ON CONFLICT DO NOTHING on token_usage means the cost from the + # first finalize stays authoritative. + second_acc = TurnTokenAccumulator() + second_acc.add( + model="gpt-4o-mini", + prompt_tokens=999, + completion_tokens=999, + total_tokens=1998, + cost_micros=99999, + ) + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=[{"type": "text", "text": "second finalize"}], + accumulator=second_acc, + ) + + # Content was overwritten by the second UPDATE. + row = await db_session.get(NewChatMessage, msg_id) + await db_session.refresh(row) + assert row.content == [{"type": "text", "text": "second finalize"}] + + # But token_usage stayed at exactly one row, preserving the + # first finalize's authoritative cost. + assert await _count_token_usage_rows(db_session, msg_id) == 1 + usage = ( + await db_session.execute( + select(TokenUsage).where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage.cost_micros == 12345 # First finalize's value + + async def test_append_message_style_insert_after_finalize_no_dupe( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + """Cross-writer race: ``append_message`` arrives after ``finalize_assistant_turn``. 
+ + Both target the same ``message_id``; the partial unique index + ``uq_token_usage_message_id`` (migration 142) makes the second + insert a no-op via ``ON CONFLICT DO NOTHING``. + """ + from sqlalchemy import text as sa_text + + thread_id = db_thread.id + user_uuid = db_user.id + user_id_str = str(user_uuid) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:7000" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=[{"type": "text", "text": "from server"}], + accumulator=_accumulator_with_one_call(), + ) + + # Now simulate the FE's append_message branch firing AFTER — + # the same INSERT ... ON CONFLICT DO NOTHING shape used by the + # route handler, keyed on the migration-142 partial unique + # index. + late_insert = ( + pg_insert(TokenUsage) + .values( + usage_type="chat", + prompt_tokens=42, + completion_tokens=42, + total_tokens=84, + cost_micros=1, + model_breakdown=None, + call_details=None, + thread_id=thread_id, + message_id=msg_id, + search_space_id=search_space_id, + user_id=user_uuid, + ) + .on_conflict_do_nothing( + index_elements=["message_id"], + index_where=sa_text("message_id IS NOT NULL"), + ) + ) + await db_session.execute(late_insert) + await db_session.flush() + + # Still exactly one row, with the original (server) cost value. + assert await _count_token_usage_rows(db_session, msg_id) == 1 + usage = ( + await db_session.execute( + select(TokenUsage).where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage.cost_micros == 12345 + + async def test_helper_never_raises_on_missing_message_id( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + + # message_id that doesn't exist — finalize must log+return, + # never raise (called from shielded finally). + await finalize_assistant_turn( + message_id=999_999_999, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id="anything", + content=[{"type": "text", "text": "x"}], + accumulator=_accumulator_with_one_call(), + ) + # If we got here without an exception, the test passes. + # Sanity: no token_usage row created (FK to message would have + # been rejected anyway, but ON CONFLICT path may swallow + # FK errors as well; check directly). + assert await _count_token_usage_rows(db_session, 999_999_999) == 0 diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py index 397b1c787..36fe04aa2 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py +++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py @@ -226,6 +226,31 @@ class TestCompose: # Default block should NOT be present assert "" not in prompt + def test_provider_hints_render_with_custom_system_instructions( + self, fixed_today: datetime + ) -> None: + """Regression guard for the always-append decision: provider hints + append AFTER a custom system prompt. + + Provider hints are stylistic nudges (parallel tool-call rules, + formatting guidance, etc.) that help the model regardless of + what the system instructions say. 
Suppressing them when a + custom prompt is set would partially defeat the per-family + prompt machinery. + """ + prompt = compose_system_prompt( + today=fixed_today, + custom_system_instructions="You are a custom assistant.", + model_name="anthropic/claude-3-5-sonnet", + ) + assert "You are a custom assistant." in prompt + assert "" in prompt + # The custom prompt must come BEFORE the provider hints so the + # user's framing isn't drowned out by the stylistic nudges. + assert prompt.index("You are a custom assistant.") < prompt.index( + "" + ) + def test_use_default_false_with_no_custom_yields_no_system_block( self, fixed_today: datetime ) -> None: diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py new file mode 100644 index 000000000..9b3de2db7 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py @@ -0,0 +1,268 @@ +"""Regression tests for the compiled-agent cache. + +Covers the cache primitive itself (TTL, LRU, in-flight de-duplication, +build-failure non-caching) and the cache-key signature helpers that +``create_surfsense_deep_agent`` relies on. The integration with +``create_surfsense_deep_agent`` is covered separately by the streaming +contract tests; this module focuses on the primitives so a regression +in the cache implementation is caught before it reaches the agent +factory. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass + +import pytest + +from app.agents.new_chat.agent_cache import ( + flags_signature, + reload_for_tests, + stable_hash, + system_prompt_hash, + tools_signature, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# stable_hash + signature helpers +# --------------------------------------------------------------------------- + + +def test_stable_hash_is_deterministic_across_calls() -> None: + a = stable_hash("v1", 42, "thread-9", None, ["x", "y"]) + b = stable_hash("v1", 42, "thread-9", None, ["x", "y"]) + assert a == b + + +def test_stable_hash_changes_when_any_part_changes() -> None: + base = stable_hash("v1", 42, "thread-9") + assert stable_hash("v1", 42, "thread-10") != base + assert stable_hash("v2", 42, "thread-9") != base + assert stable_hash("v1", 43, "thread-9") != base + + +def test_tools_signature_keys_on_name_and_description_not_identity() -> None: + """Two tool lists with the same surface must hash identically. + + The cache key MUST NOT change when the underlying ``BaseTool`` + instances are different Python objects (a fresh request constructs + fresh tool instances every time). Hashing on ``(name, description)`` + keeps the cache hot across requests with identical tool surfaces. + """ + + @dataclass + class FakeTool: + name: str + description: str + + tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")] + tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")] + sig_a = tools_signature( + tools_a, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + sig_b = tools_signature( + tools_b, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + assert sig_a == sig_b, "tool order must not affect the signature" + + # Adding a tool rotates the key. 
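+    # (Both behaviours follow if, as an assumption about the internals,
+    #  the signature is roughly
+    #      stable_hash(sorted((t.name, t.description) for t in tools),
+    #                  sorted(connectors), sorted(doc_types)):
+    #  sorting buys order-insensitivity, and any new (name, description)
+    #  pair rotates the digest.)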
+ tools_c = [*tools_a, FakeTool("gamma", "does gamma")] + sig_c = tools_signature( + tools_c, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + assert sig_c != sig_a + + +def test_tools_signature_rotates_when_connector_set_changes() -> None: + @dataclass + class FakeTool: + name: str + description: str + + tools = [FakeTool("a", "x")] + base = tools_signature( + tools, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + added = tools_signature( + tools, + available_connectors=["NOTION", "SLACK"], + available_document_types=["FILE"], + ) + assert base != added, "adding a connector must rotate the cache key" + + +def test_flags_signature_changes_when_flag_flips() -> None: + @dataclass(frozen=True) + class Flags: + a: bool = True + b: bool = False + + base = flags_signature(Flags()) + flipped = flags_signature(Flags(b=True)) + assert base != flipped + + +def test_system_prompt_hash_is_stable_and_distinct() -> None: + p1 = "You are a helpful assistant." + p2 = "You are a helpful assistant!" # one-character delta + assert system_prompt_hash(p1) == system_prompt_hash(p1) + assert system_prompt_hash(p1) != system_prompt_hash(p2) + + +# --------------------------------------------------------------------------- +# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cache_hit_returns_same_instance_on_second_call() -> None: + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + builds = 0 + + async def builder() -> object: + nonlocal builds + builds += 1 + return object() + + a = await cache.get_or_build("k", builder=builder) + b = await cache.get_or_build("k", builder=builder) + assert a is b, "cache must return the SAME object across hits" + assert builds == 1, "builder must run exactly once" + + +@pytest.mark.asyncio +async def test_cache_different_keys_get_different_instances() -> None: + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + a = await cache.get_or_build("k1", builder=builder) + b = await cache.get_or_build("k2", builder=builder) + assert a is not b + + +@pytest.mark.asyncio +async def test_cache_stale_entries_get_rebuilt() -> None: + # ttl=0 means every read sees the entry as immediately stale. + cache = reload_for_tests(maxsize=8, ttl_seconds=0.0) + builds = 0 + + async def builder() -> object: + nonlocal builds + builds += 1 + return object() + + a = await cache.get_or_build("k", builder=builder) + b = await cache.get_or_build("k", builder=builder) + assert a is not b, "stale entry must rebuild a fresh instance" + assert builds == 2 + + +@pytest.mark.asyncio +async def test_cache_evicts_lru_when_full() -> None: + cache = reload_for_tests(maxsize=2, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + a = await cache.get_or_build("a", builder=builder) + _ = await cache.get_or_build("b", builder=builder) + # Re-touch "a" so "b" is now the LRU victim. + a_again = await cache.get_or_build("a", builder=builder) + assert a_again is a + # Inserting "c" should evict "b" (LRU), not "a". + _ = await cache.get_or_build("c", builder=builder) + assert cache.stats()["size"] == 2 + + # Confirm "a" is still hot (no rebuild) and "b" is gone (rebuild). 
+ a_hit = await cache.get_or_build("a", builder=builder) + assert a_hit is a, "LRU must keep the most-recently-used 'a' entry" + + +@pytest.mark.asyncio +async def test_cache_concurrent_misses_coalesce_to_single_build() -> None: + """Two concurrent get_or_build calls on the same key must share one builder.""" + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + build_started = asyncio.Event() + builds = 0 + + async def slow_builder() -> object: + nonlocal builds + builds += 1 + build_started.set() + # Yield control so the second waiter can race against us. + await asyncio.sleep(0.05) + return object() + + task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder)) + # Wait until the first builder has started, then race a second waiter. + await build_started.wait() + task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder)) + + a, b = await asyncio.gather(task_a, task_b) + assert a is b, "coalesced waiters must observe the same value" + assert builds == 1, "concurrent cold misses must collapse to ONE build" + + +@pytest.mark.asyncio +async def test_cache_does_not_store_failed_builds() -> None: + """A builder that raises must NOT poison the cache. + + The next caller for the same key must run the builder again (not + re-raise the cached exception). + """ + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + attempts = 0 + + async def flaky_builder() -> object: + nonlocal attempts + attempts += 1 + if attempts == 1: + raise RuntimeError("transient") + return object() + + with pytest.raises(RuntimeError, match="transient"): + await cache.get_or_build("k", builder=flaky_builder) + + # Second call must retry — not re-raise the cached exception. + value = await cache.get_or_build("k", builder=flaky_builder) + assert value is not None + assert attempts == 2 + + +@pytest.mark.asyncio +async def test_cache_invalidate_drops_entry() -> None: + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + a = await cache.get_or_build("k", builder=builder) + assert cache.invalidate("k") is True + b = await cache.get_or_build("k", builder=builder) + assert a is not b, "post-invalidation lookup must rebuild" + + +@pytest.mark.asyncio +async def test_cache_invalidate_prefix_drops_matching_entries() -> None: + cache = reload_for_tests(maxsize=16, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + await cache.get_or_build("user:1:thread:1", builder=builder) + await cache.get_or_build("user:1:thread:2", builder=builder) + await cache.get_or_build("user:2:thread:1", builder=builder) + + removed = cache.invalidate_prefix("user:1:") + assert removed == 2 + assert cache.stats()["size"] == 1 + + # The user:2 entry must still be hot (no rebuild). 
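+    # (invalidate_prefix presumably does a startswith scan over the key
+    # space; a sketch, assuming an OrderedDict store:
+    #     victims = [k for k in self._entries if k.startswith(prefix)]
+    #     for k in victims: del self._entries[k]
+    #     return len(victims)
+    # which matches the removed == 2 count asserted above.)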
+ survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder) + assert survivor_value is not None diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py index 0c7bf17f6..f0161f605 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py @@ -7,7 +7,9 @@ import pytest from app.agents.new_chat.errors import BusyError from app.agents.new_chat.middleware.busy_mutex import ( BusyMutexMiddleware, + end_turn, get_cancel_event, + is_cancel_requested, manager, request_cancel, reset_cancel, @@ -88,3 +90,65 @@ async def test_no_thread_id_skipped_when_not_required() -> None: def test_reset_cancel_idempotent() -> None: # Should not raise even if event was never created reset_cancel("never-seen") + + +def test_request_cancel_creates_event_for_unseen_thread() -> None: + thread_id = "never-seen-cancel" + reset_cancel(thread_id) + + assert request_cancel(thread_id) is True + assert get_cancel_event(thread_id).is_set() + assert is_cancel_requested(thread_id) is True + + +@pytest.mark.asyncio +async def test_end_turn_force_clears_lock_and_cancel_state() -> None: + thread_id = "forced-end-turn" + mw = BusyMutexMiddleware() + runtime = _Runtime(thread_id) + + await mw.abefore_agent({}, runtime) + assert manager.lock_for(thread_id).locked() + + request_cancel(thread_id) + assert is_cancel_requested(thread_id) is True + + end_turn(thread_id) + + assert not manager.lock_for(thread_id).locked() + assert not get_cancel_event(thread_id).is_set() + assert is_cancel_requested(thread_id) is False + + +@pytest.mark.asyncio +async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None: + """A stale aafter call from attempt A must not unlock attempt B. + + Repro flow: + 1) attempt A acquires thread lock + 2) forced end_turn clears A so retry can proceed + 3) attempt B acquires same thread lock + 4) stale attempt-A aafter runs late + + Expected: B lock remains held. + """ + thread_id = "stale-aafter-lock" + runtime = _Runtime(thread_id) + attempt_a = BusyMutexMiddleware() + attempt_b = BusyMutexMiddleware() + + await attempt_a.abefore_agent({}, runtime) + lock = manager.lock_for(thread_id) + assert lock.locked() + + end_turn(thread_id) + assert not lock.locked() + + await attempt_b.abefore_agent({}, runtime) + assert lock.locked() + + # Stale cleanup from attempt A must not release attempt B's lock. 
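+    # (The guard this implies, an assumed shape rather than the verbatim
+    # middleware: abefore_agent records an ownership token per attempt,
+    # e.g. self._owner[thread_id] = token, and aafter_agent releases the
+    # lock only while the manager still maps the thread to *its* token;
+    # a forced end_turn rotates the token, so A's late release is a no-op.)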
+ await attempt_a.aafter_agent({}, runtime) + assert lock.locked() + + await attempt_b.aafter_agent({}, runtime) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py index 38a70a443..6800be2af 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py @@ -31,18 +31,45 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None: "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "SURFSENSE_ENABLE_ACTION_LOG", "SURFSENSE_ENABLE_REVERT_ROUTE", + "SURFSENSE_ENABLE_STREAM_PARITY_V2", "SURFSENSE_ENABLE_PLUGIN_LOADER", "SURFSENSE_ENABLE_OTEL", + "SURFSENSE_ENABLE_AGENT_CACHE", + "SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", ]: monkeypatch.delenv(name, raising=False) -def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None: +def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None: _clear_all(monkeypatch) flags = reload_for_tests() assert isinstance(flags, AgentFeatureFlags) assert flags.disable_new_agent_stack is False - assert flags.any_new_middleware_enabled() is False + assert flags.enable_context_editing is True + assert flags.enable_compaction_v2 is True + assert flags.enable_retry_after is True + assert flags.enable_model_fallback is False + assert flags.enable_model_call_limit is True + assert flags.enable_tool_call_limit is True + assert flags.enable_tool_call_repair is True + assert flags.enable_doom_loop is True + assert flags.enable_permission is True + assert flags.enable_busy_mutex is True + assert flags.enable_llm_tool_selector is False + assert flags.enable_skills is True + assert flags.enable_specialized_subagents is True + assert flags.enable_kb_planner_runnable is True + assert flags.enable_action_log is True + assert flags.enable_revert_route is True + assert flags.enable_stream_parity_v2 is True + assert flags.enable_plugin_loader is False + assert flags.enable_otel is False + # Phase 2: agent cache is now default-on (the prerequisite tool + # ``db_session`` refactor landed). The companion gp-subagent share + # flag stays default-off pending data on cold-miss frequency. + assert flags.enable_agent_cache is True + assert flags.enable_agent_cache_share_gp_subagent is False + assert flags.any_new_middleware_enabled() is True def test_master_kill_switch_overrides_individual_flags( @@ -100,21 +127,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) -> "enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG", "enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE", + "enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2", "enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER", "enable_otel": "SURFSENSE_ENABLE_OTEL", } - # `enable_otel` is intentionally orthogonal — it does NOT count toward - # ``any_new_middleware_enabled`` because OTel is observability-only and - # ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement. 
- counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"} - for attr, env_name in flag_to_env.items(): _clear_all(monkeypatch) - monkeypatch.setenv(env_name, "true") + monkeypatch.setenv(env_name, "false") flags = reload_for_tests() - assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}" - if attr in counts_toward_middleware: - assert flags.any_new_middleware_enabled() is True - else: - assert flags.any_new_middleware_enabled() is False + assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py new file mode 100644 index 000000000..6c323d920 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py @@ -0,0 +1,344 @@ +"""Tests for ``FlattenSystemMessageMiddleware``. + +The middleware exists to defend against Anthropic's "Found 5 cache_control +blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on +the system message and the OpenRouter→Anthropic adapter redistributes +``cache_control`` across all of them. The flattening collapses every +all-text system content list to a single string before the LLM call. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import HumanMessage, SystemMessage + +from app.agents.new_chat.middleware.flatten_system import ( + FlattenSystemMessageMiddleware, + _flatten_text_blocks, + _flattened_request, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# _flatten_text_blocks — pure helper, the heart of the middleware. +# --------------------------------------------------------------------------- + + +class TestFlattenTextBlocks: + def test_joins_text_blocks_with_double_newline(self) -> None: + blocks = [ + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + ] + assert ( + _flatten_text_blocks(blocks) + == "\n\n\n\n" + ) + + def test_handles_single_text_block(self) -> None: + blocks = [{"type": "text", "text": "only one"}] + assert _flatten_text_blocks(blocks) == "only one" + + def test_handles_empty_list(self) -> None: + assert _flatten_text_blocks([]) == "" + + def test_passes_through_bare_string_blocks(self) -> None: + # LangChain content can mix bare strings and dict blocks. + blocks = ["raw string", {"type": "text", "text": "dict block"}] + assert _flatten_text_blocks(blocks) == "raw string\n\ndict block" + + def test_returns_none_for_image_block(self) -> None: + # System messages with images are rare — but we never want to + # silently lose the image payload by joining as text. 
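+        # (Returning None is the helper's "don't touch it" signal: the
+        # caller keeps the original content list instead of a lossy
+        # text join.)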
+ blocks = [ + {"type": "text", "text": "look at this"}, + {"type": "image_url", "image_url": {"url": "data:image/png..."}}, + ] + assert _flatten_text_blocks(blocks) is None + + def test_returns_none_for_non_dict_non_str_block(self) -> None: + blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item] + assert _flatten_text_blocks(blocks) is None + + def test_returns_none_when_text_field_missing(self) -> None: + blocks = [{"type": "text"}] # no ``text`` key + assert _flatten_text_blocks(blocks) is None + + def test_returns_none_when_text_is_not_string(self) -> None: + blocks = [{"type": "text", "text": ["nested", "list"]}] + assert _flatten_text_blocks(blocks) is None + + def test_drops_cache_control_from_inner_blocks(self) -> None: + # The whole point: existing cache_control on inner blocks is + # discarded so LiteLLM's ``cache_control_injection_points`` can + # re-attach exactly one breakpoint after flattening. + blocks = [ + {"type": "text", "text": "first"}, + { + "type": "text", + "text": "second", + "cache_control": {"type": "ephemeral"}, + }, + ] + flattened = _flatten_text_blocks(blocks) + assert flattened == "first\n\nsecond" + assert "cache_control" not in flattened # type: ignore[operator] + + +# --------------------------------------------------------------------------- +# _flattened_request — decides when to override and when to no-op. +# --------------------------------------------------------------------------- + + +def _make_request(system_message: SystemMessage | None) -> Any: + """Build a minimal ModelRequest stub. We only need .system_message + and .override(system_message=...) — the middleware never touches + other fields. + """ + request = MagicMock() + request.system_message = system_message + + def override(**kwargs: Any) -> Any: + new_request = MagicMock() + new_request.system_message = kwargs.get( + "system_message", request.system_message + ) + new_request.messages = kwargs.get("messages", getattr(request, "messages", [])) + new_request.tools = kwargs.get("tools", getattr(request, "tools", [])) + return new_request + + request.override = override + return request + + +class TestFlattenedRequest: + def test_collapses_multi_block_system_to_string(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + ] + ) + request = _make_request(sys) + flattened = _flattened_request(request) + + assert flattened is not None + assert isinstance(flattened.system_message, SystemMessage) + assert flattened.system_message.content == ( + "\n\n\n\n\n\n\n\n" + ) + + def test_no_op_for_string_content(self) -> None: + sys = SystemMessage(content="already a string") + request = _make_request(sys) + assert _flattened_request(request) is None + + def test_no_op_for_single_block_list(self) -> None: + # One block already produces one breakpoint — no need to flatten. 
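+        # (Flattening here would only churn message identity for zero
+        # cache_control benefit, so the helper deliberately returns None.)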
+ sys = SystemMessage(content=[{"type": "text", "text": "single"}]) + request = _make_request(sys) + assert _flattened_request(request) is None + + def test_no_op_when_system_message_missing(self) -> None: + request = _make_request(None) + assert _flattened_request(request) is None + + def test_no_op_when_list_contains_non_text_block(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": "look"}, + {"type": "image_url", "image_url": {"url": "data:..."}}, + ] + ) + request = _make_request(sys) + assert _flattened_request(request) is None + + def test_preserves_additional_kwargs_and_metadata(self) -> None: + # Defensive: nothing in the current chain sets these on a system + # message, but losing them silently when something does in the + # future would be a regression. ``name`` in particular is the only + # ``additional_kwargs`` field that ChatLiteLLM's + # ``_convert_message_to_dict`` propagates onto the wire. + sys = SystemMessage( + content=[ + {"type": "text", "text": "a"}, + {"type": "text", "text": "b"}, + ], + additional_kwargs={"name": "surfsense_system", "x": 1}, + response_metadata={"tokens": 42}, + ) + sys.id = "sys-msg-1" + request = _make_request(sys) + + flattened = _flattened_request(request) + assert flattened is not None + assert flattened.system_message.content == "a\n\nb" + assert flattened.system_message.additional_kwargs == { + "name": "surfsense_system", + "x": 1, + } + assert flattened.system_message.response_metadata == {"tokens": 42} + assert flattened.system_message.id == "sys-msg-1" + + def test_idempotent_when_run_twice(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": "a"}, + {"type": "text", "text": "b"}, + ] + ) + request = _make_request(sys) + first = _flattened_request(request) + assert first is not None + + # Second pass on the already-flattened request should be a no-op. + # We re-wrap in a request stub since the helper inspects + # ``request.system_message.content``. + second_request = _make_request(first.system_message) + assert _flattened_request(second_request) is None + + +# --------------------------------------------------------------------------- +# Middleware integration — verify the handler sees a flattened request. +# --------------------------------------------------------------------------- + + +class TestMiddlewareWrap: + @pytest.mark.asyncio + async def test_async_passes_flattened_request_to_handler(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": "alpha"}, + {"type": "text", "text": "beta"}, + ] + ) + request = _make_request(sys) + captured: dict[str, Any] = {} + + async def handler(req: Any) -> str: + captured["request"] = req + return "ok" + + mw = FlattenSystemMessageMiddleware() + result = await mw.awrap_model_call(request, handler) + + assert result == "ok" + assert isinstance(captured["request"].system_message, SystemMessage) + assert captured["request"].system_message.content == "alpha\n\nbeta" + + @pytest.mark.asyncio + async def test_async_passes_through_when_already_string(self) -> None: + sys = SystemMessage(content="just a string") + request = _make_request(sys) + captured: dict[str, Any] = {} + + async def handler(req: Any) -> str: + captured["request"] = req + return "ok" + + mw = FlattenSystemMessageMiddleware() + await mw.awrap_model_call(request, handler) + + # Same request object: no override happened. 
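+        # (Identity, not equality: the no-op path must hand the original
+        # request straight to the handler rather than rebuild a copy.)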
+        assert captured["request"] is request

+    def test_sync_passes_flattened_request_to_handler(self) -> None:
+        sys = SystemMessage(
+            content=[
+                {"type": "text", "text": "alpha"},
+                {"type": "text", "text": "beta"},
+            ]
+        )
+        request = _make_request(sys)
+        captured: dict[str, Any] = {}
+
+        def handler(req: Any) -> str:
+            captured["request"] = req
+            return "ok"
+
+        mw = FlattenSystemMessageMiddleware()
+        result = mw.wrap_model_call(request, handler)
+
+        assert result == "ok"
+        assert captured["request"].system_message.content == "alpha\n\nbeta"
+
+    def test_sync_passes_through_when_no_system_message(self) -> None:
+        request = _make_request(None)
+        captured: dict[str, Any] = {}
+
+        def handler(req: Any) -> str:
+            captured["request"] = req
+            return "ok"
+
+        mw = FlattenSystemMessageMiddleware()
+        mw.wrap_model_call(request, handler)
+        assert captured["request"] is request
+
+
+# ---------------------------------------------------------------------------
+# Regression guard — pin the worst-case shape that triggered the
+# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
+# downstream cache_control_injection_points can only place 1 breakpoint
+# on the system message regardless of provider redistribution quirks.
+# ---------------------------------------------------------------------------
+
+
+def test_regression_five_block_system_collapses_to_one_block() -> None:
+    sys = SystemMessage(
+        content=[
+            {"type": "text", "text": ""},
+            {"type": "text", "text": ""},
+            {"type": "text", "text": ""},
+            {"type": "text", "text": ""},
+            {"type": "text", "text": ""},
+        ]
+    )
+    request = _make_request(sys)
+    flattened = _flattened_request(request)
+
+    assert flattened is not None
+    assert isinstance(flattened.system_message.content, str)
+    # The exact join doesn't matter for the cache_control accounting —
+    # only that there is exactly ONE content block when LiteLLM's
+    # AnthropicCacheControlHook later targets ``role: system``.
+    assert "\n\n" in flattened.system_message.content
+
+
+def test_user_message_blocks_are_never_flattened() -> None:
+    # Sanity: the middleware MUST NOT touch user messages — only the
+    # system message. Multi-block user content is the path that carries
+    # image attachments and would lose its image_url block on
+    # accidental flatten.
+    sys = SystemMessage(
+        content=[
+            {"type": "text", "text": "a"},
+            {"type": "text", "text": "b"},
+        ]
+    )
+    user = HumanMessage(
+        content=[
+            {"type": "text", "text": "look at this"},
+            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
+        ]
+    )
+    request = _make_request(sys)
+    request.messages = [user]
+
+    flattened = _flattened_request(request)
+    assert flattened is not None
+    # System flattened to string …
+    assert isinstance(flattened.system_message.content, str)
+    # … user message is untouched (the helper does not even look at it).
+ assert flattened.messages == [user] + assert isinstance(user.content, list) + assert len(user.content) == 2 diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py index 0bbdf37bf..d0ea73376 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py @@ -27,6 +27,7 @@ class TestDefaultAutoApprovedToolsList: expected = { "create_gmail_draft", "update_gmail_draft", + "create_calendar_event", "create_notion_page", "create_confluence_page", "create_google_drive_file", @@ -41,13 +42,12 @@ class TestDefaultAutoApprovedToolsList: assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset) def test_send_tools_are_not_auto_approved(self) -> None: - # External-broadcast tools must always prompt. + # External-broadcast / destructive tools must always prompt. for tool_name in ( "send_gmail_email", "send_discord_message", "send_teams_message", "delete_notion_page", - "create_calendar_event", "delete_calendar_event", ): assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, ( diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py new file mode 100644 index 000000000..4cf53969d --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py @@ -0,0 +1,370 @@ +r"""Tests for ``apply_litellm_prompt_caching`` in +:mod:`app.agents.new_chat.prompt_caching`. + +The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which +never activated for our LiteLLM stack) with LiteLLM-native multi-provider +prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to +``litellm.completion(...)``. The tests below pin its public contract: + +1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so + savings compound across multi-turn conversations on Anthropic-family + providers. ``index: 0`` is used (rather than ``role: system``) because + the deepagent stack accumulates multiple ``SystemMessage``\ s in + ``state["messages"]`` and ``role: system`` would tag every one of + them, blowing past Anthropic's 4-block ``cache_control`` cap. +2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for + single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic + prompt-cache surface is available). +3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only — no + OpenAI-only kwargs because the router fans out across providers. +4. Idempotent: user-supplied values in ``model_kwargs`` are preserved. +5. Defensive: LLMs without a writable ``model_kwargs`` are silently + skipped rather than raising. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from app.agents.new_chat.llm_config import AgentConfig +from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Test doubles +# --------------------------------------------------------------------------- + + +class _FakeLLM: + """Stand-in for ``ChatLiteLLM``/``SanitizedChatLiteLLM``. + + The helper only inspects ``getattr(llm, "model_kwargs", None)``, + ``getattr(llm, "model", None)``, and ``type(llm).__name__``. A simple + object suffices — we don't need to spin up real LangChain/LiteLLM + machinery for unit tests of the helper's logic. 
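+
+    A typical probe from the tests below::
+
+        llm = _FakeLLM(model="anthropic/claude-3-5-sonnet")
+        apply_litellm_prompt_caching(llm)
+        assert len(llm.model_kwargs["cache_control_injection_points"]) == 2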
+ """ + + def __init__( + self, + model: str = "openai/gpt-4o", + model_kwargs: dict[str, Any] | None = None, + ) -> None: + self.model = model + self.model_kwargs: dict[str, Any] = dict(model_kwargs) if model_kwargs else {} + + +class ChatLiteLLMRouter: + """Class-name-only impostor of the real router. + + The helper's router gate is ``type(llm).__name__ == "ChatLiteLLMRouter"`` + (a deliberate stringly-typed check to avoid an import cycle with + ``app.services.llm_router_service``). Reusing the same class name here + triggers the same code path without instantiating a real ``Router``. + """ + + def __init__(self) -> None: + self.model = "auto" + self.model_kwargs: dict[str, Any] = {} + + +def _make_cfg(**overrides: Any) -> AgentConfig: + """Build an ``AgentConfig`` with sensible defaults for the helper test.""" + defaults: dict[str, Any] = { + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "k", + } + return AgentConfig(**{**defaults, **overrides}) + + +# --------------------------------------------------------------------------- +# (a) Universal injection points +# --------------------------------------------------------------------------- + + +def test_sets_both_cache_control_injection_points_with_no_config() -> None: + """Bare call (no agent_config, no thread_id) still sets the two + universal breakpoints — these cost nothing on providers that don't + consume them and unlock caching on every supported provider.""" + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm) + + points = llm.model_kwargs["cache_control_injection_points"] + assert {"location": "message", "index": 0} in points + assert {"location": "message", "index": -1} in points + assert len(points) == 2 + + +def test_does_not_inject_role_system_breakpoint() -> None: + """Regression: deliberately AVOID ``role: system`` so we don't tag + every SystemMessage the deepagent ``before_agent`` injectors push + into ``state["messages"]`` (priority, tree, memory, file-intent, + anonymous-doc). Tagging all of them overflows Anthropic's 4-block + ``cache_control`` cap and surfaces as + ``OpenrouterException: A maximum of 4 blocks with cache_control may + be provided. Found N`` 400s. + """ + llm = _FakeLLM() + apply_litellm_prompt_caching(llm) + points = llm.model_kwargs["cache_control_injection_points"] + assert all(p.get("role") != "system" for p in points), ( + f"Expected no role=system breakpoint, got: {points}" + ) + + +def test_injection_points_set_for_anthropic_config() -> None: + """Anthropic-family configs need the marker — verify it lands.""" + cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet") + llm = _FakeLLM(model="anthropic/claude-3-5-sonnet") + + apply_litellm_prompt_caching(llm, agent_config=cfg) + + assert "cache_control_injection_points" in llm.model_kwargs + + +# --------------------------------------------------------------------------- +# (b) Idempotency / user override wins +# --------------------------------------------------------------------------- + + +def test_does_not_overwrite_user_supplied_cache_control_injection_points() -> None: + """Users who set their own injection points (e.g. 
with ``ttl: "1h"`` + via ``litellm_params``) keep them — the helper merges, never + clobbers.""" + user_points = [ + {"location": "message", "role": "system", "ttl": "1h"}, + ] + llm = _FakeLLM( + model_kwargs={"cache_control_injection_points": user_points}, + ) + + apply_litellm_prompt_caching(llm) + + assert llm.model_kwargs["cache_control_injection_points"] is user_points + + +def test_idempotent_when_called_multiple_times() -> None: + """Build-time + thread-time double-call must be a no-op the second time.""" + cfg = _make_cfg(provider="OPENAI") + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1) + snapshot = { + "cache_control_injection_points": list( + llm.model_kwargs["cache_control_injection_points"] + ), + "prompt_cache_key": llm.model_kwargs["prompt_cache_key"], + "prompt_cache_retention": llm.model_kwargs["prompt_cache_retention"], + } + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1) + + assert ( + llm.model_kwargs["cache_control_injection_points"] + == snapshot["cache_control_injection_points"] + ) + assert llm.model_kwargs["prompt_cache_key"] == snapshot["prompt_cache_key"] + assert ( + llm.model_kwargs["prompt_cache_retention"] == snapshot["prompt_cache_retention"] + ) + + +def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None: + """A pre-set ``prompt_cache_key`` (e.g. tenant-aware override via + ``litellm_params``) wins over our default per-thread key.""" + cfg = _make_cfg(provider="OPENAI") + llm = _FakeLLM(model_kwargs={"prompt_cache_key": "tenant-abc"}) + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert llm.model_kwargs["prompt_cache_key"] == "tenant-abc" + + +# --------------------------------------------------------------------------- +# (c) OpenAI-family extras (OPENAI / DEEPSEEK / XAI) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"]) +def test_sets_openai_family_extras(provider: str) -> None: + """OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate + via routing affinity) and ``prompt_cache_retention="24h"`` (extends + cache TTL beyond the default 5-10 min).""" + cfg = _make_cfg(provider=provider) + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-42" + assert llm.model_kwargs["prompt_cache_retention"] == "24h" + + +def test_skips_prompt_cache_key_when_no_thread_id() -> None: + """Without a thread id we can't construct a per-thread key. Retention + is still useful so we set it (it's free).""" + cfg = _make_cfg(provider="OPENAI") + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=None) + + assert "prompt_cache_key" not in llm.model_kwargs + assert llm.model_kwargs["prompt_cache_retention"] == "24h" + + +@pytest.mark.parametrize( + "provider", + ["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"], +) +def test_no_openai_extras_for_other_providers(provider: str) -> None: + """Non-OpenAI-family providers don't expose ``prompt_cache_key`` — + skip it. 
``cache_control_injection_points`` is still set (universal).""" + cfg = _make_cfg(provider=provider) + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + assert "cache_control_injection_points" in llm.model_kwargs + + +def test_no_openai_extras_in_auto_mode() -> None: + """Auto-mode fans out across mixed providers — we can't statically + target OpenAI-only kwargs.""" + cfg = AgentConfig.from_auto_mode() + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + assert "cache_control_injection_points" in llm.model_kwargs + + +def test_no_openai_extras_for_custom_provider() -> None: + """Custom providers route through arbitrary user-supplied prefixes — + we don't try to infer OpenAI-family compatibility.""" + cfg = _make_cfg(provider="OPENAI", custom_provider="my_proxy") + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + + +# --------------------------------------------------------------------------- +# (d) ChatLiteLLMRouter — universal injection points only +# --------------------------------------------------------------------------- + + +def test_router_llm_gets_only_universal_injection_points() -> None: + """Even with an OpenAI-flavoured config, a ``ChatLiteLLMRouter`` must + receive only the universal injection points — its requests dispatch + across provider deployments and OpenAI-only kwargs would be wasted + (or stripped by ``drop_params``) on non-OpenAI legs.""" + router = ChatLiteLLMRouter() + cfg = _make_cfg(provider="OPENAI") + + apply_litellm_prompt_caching(router, agent_config=cfg, thread_id=42) + + assert "cache_control_injection_points" in router.model_kwargs + assert "prompt_cache_key" not in router.model_kwargs + assert "prompt_cache_retention" not in router.model_kwargs + + +# --------------------------------------------------------------------------- +# (e) Defensive paths +# --------------------------------------------------------------------------- + + +def test_handles_llm_with_no_writable_model_kwargs() -> None: + """Some LLM implementations (e.g. fakes / minimal subclasses) don't + expose a writable ``model_kwargs``. The helper must skip silently — + raising would crash the entire LLM build path on a non-critical + optimisation.""" + + class _ImmutableLLM: + # ``__slots__`` blocks attribute creation, so ``setattr`` raises. + __slots__ = ("model",) + + def __init__(self) -> None: + self.model = "openai/gpt-4o" + + llm = _ImmutableLLM() + + apply_litellm_prompt_caching(llm) + + +def test_initialises_missing_model_kwargs_dict() -> None: + """When ``model_kwargs`` is present-but-None (Pydantic v2 default + pattern when no factory is set), the helper initialises it to an + empty dict before mutating.""" + + class _LazyLLM: + def __init__(self) -> None: + self.model = "openai/gpt-4o" + self.model_kwargs: dict[str, Any] | None = None + + llm = _LazyLLM() + + apply_litellm_prompt_caching(llm) + + assert isinstance(llm.model_kwargs, dict) + assert "cache_control_injection_points" in llm.model_kwargs + + +def test_falls_back_to_llm_model_prefix_when_no_agent_config() -> None: + """Direct caller path (e.g. 
``create_chat_litellm_from_config`` for + YAML configs without a structured ``AgentConfig``): without + ``agent_config`` the helper sets only the universal injection points + — no OpenAI-family extras even if the prefix says ``openai/``. + Conservative: we'd rather miss the speedup than silently misroute.""" + llm = _FakeLLM(model="openai/gpt-4o") + + apply_litellm_prompt_caching(llm, agent_config=None, thread_id=99) + + assert "cache_control_injection_points" in llm.model_kwargs + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + + +# --------------------------------------------------------------------------- +# (f) drop_params safety net (regression guard for #19346) +# --------------------------------------------------------------------------- + + +def test_litellm_drop_params_is_globally_enabled() -> None: + """``litellm.drop_params=True`` is set globally in + :mod:`app.services.llm_service` so any ``prompt_cache_key`` / + ``prompt_cache_retention`` we set on an OpenAI-family config is + auto-stripped if the request later routes to a non-supporting + provider (e.g. via auto-mode router fallback). This test pins that + invariant — losing it would mean Bedrock/Vertex 400s on ``prompt_cache_key``. + """ + import litellm + + import app.services.llm_service # noqa: F401 (side-effect: sets globals) + + assert litellm.drop_params is True + + +# --------------------------------------------------------------------------- +# Regression note: LiteLLM #15696 (multi-content-block last message) +# --------------------------------------------------------------------------- +# +# Before LiteLLM 1.81 a list-form last message ``[block_a, block_b]`` +# would get ``cache_control`` applied to *every* content block instead +# of only the last one — wasting cache breakpoints and triggering 400s +# on Anthropic when it exceeded the 4-breakpoint limit. Fixed in +# https://github.com/BerriAI/litellm/pull/15699. +# +# We pin ``litellm>=1.83.7`` in ``pyproject.toml`` (well past the fix). +# An end-to-end behavioural test would need to run ``litellm.completion`` +# through the Anthropic transformer, which is integration territory and +# better covered by LiteLLM's own test suite. The unit guard here is the +# version pin plus the build-time ``model_kwargs`` shape we verify above. diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py b/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py new file mode 100644 index 000000000..ffe3dbaa4 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py @@ -0,0 +1,117 @@ +"""Tests for ``_resolve_prompt_model_name`` in :mod:`app.agents.new_chat.chat_deepagent`. + +The helper picks the model id fed to ``detect_provider_variant`` so the +right ```` block lands in the system prompt. The tests +below pin its preference order: + +1. ``agent_config.litellm_params["base_model"]`` (Azure-correct). +2. ``agent_config.model_name``. +3. ``getattr(llm, "model", None)``. + +Without (1) an Azure deployment named e.g. ``"prod-chat-001"`` would +silently miss every provider regex. 
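+
+A sketch of the order the tests pin down (an assumed shape; the real
+helper lives in ``chat_deepagent`` and may phrase it differently)::
+
+    def _resolve_prompt_model_name(agent_config, llm):
+        if agent_config is not None:
+            base = (agent_config.litellm_params or {}).get("base_model")
+            if isinstance(base, str) and base.strip():
+                return base
+            if agent_config.model_name:
+                return agent_config.model_name
+        return getattr(llm, "model", None)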
+""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name +from app.agents.new_chat.llm_config import AgentConfig + +pytestmark = pytest.mark.unit + + +def _make_cfg(**overrides) -> AgentConfig: + """Build an ``AgentConfig`` with sensible defaults for the helper test.""" + defaults = { + "provider": "OPENAI", + "model_name": "x", + "api_key": "k", + } + return AgentConfig(**{**defaults, **overrides}) + + +class _FakeLLM: + """Stand-in for a ``ChatLiteLLM`` / ``ChatLiteLLMRouter`` instance. + + The resolver only reads the ``.model`` attribute via ``getattr``, + matching the established idiom in ``knowledge_search.py`` / + ``stream_new_chat.py`` / ``document_summarizer.py``. + """ + + def __init__(self, model: str | None) -> None: + self.model = model + + +def test_prefers_litellm_params_base_model_over_deployment_name() -> None: + """Azure deployment slug must NOT shadow the underlying model family. + + This is the failure mode the helper exists to prevent: a deployment + named ``"azure/prod-chat-001"`` would not match any provider regex + on its own, but the family ``"gpt-4o"`` lives in + ``litellm_params["base_model"]`` and routes to ``openai_classic``. + """ + cfg = _make_cfg( + model_name="azure/prod-chat-001", + litellm_params={"base_model": "gpt-4o"}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("azure/prod-chat-001")) == "gpt-4o" + + +def test_falls_back_to_model_name_when_litellm_params_is_none() -> None: + cfg = _make_cfg( + model_name="anthropic/claude-3-5-sonnet", + litellm_params=None, + ) + got = _resolve_prompt_model_name(cfg, _FakeLLM("anthropic/claude-3-5-sonnet")) + assert got == "anthropic/claude-3-5-sonnet" + + +def test_handles_litellm_params_without_base_model_key() -> None: + cfg = _make_cfg( + model_name="openai/gpt-4o", + litellm_params={"temperature": 0.5}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" + + +def test_ignores_blank_base_model() -> None: + """Whitespace-only ``base_model`` must not shadow ``model_name``.""" + cfg = _make_cfg( + model_name="openai/gpt-4o", + litellm_params={"base_model": " "}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" + + +def test_ignores_non_string_base_model() -> None: + """Defensive: a non-string ``base_model`` should not crash the resolver.""" + cfg = _make_cfg( + model_name="openai/gpt-4o", + litellm_params={"base_model": 42}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" + + +def test_falls_back_to_llm_model_when_no_agent_config() -> None: + """No ``agent_config`` -> use ``llm.model`` directly. Defensive path + for direct callers; production callers always supply a config.""" + assert ( + _resolve_prompt_model_name(None, _FakeLLM("openai/gpt-4o-mini")) + == "openai/gpt-4o-mini" + ) + + +def test_returns_none_when_nothing_available() -> None: + """``compose_system_prompt`` treats ``None`` as the ``"default"`` + variant and emits no provider block.""" + assert _resolve_prompt_model_name(None, _FakeLLM(None)) is None + + +def test_auto_mode_resolves_to_auto_string() -> None: + """Auto mode -> ``"auto"``. 
``detect_provider_variant("auto")`` + returns ``"default"``, which is correct: the child model isn't + known until the LiteLLM Router dispatches.""" + cfg = AgentConfig.from_auto_mode() + assert _resolve_prompt_model_name(cfg, _FakeLLM("auto")) == "auto" diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py index 2ca470680..2933a0504 100644 --- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py @@ -475,3 +475,190 @@ class TestKBSearchPlanSchema: ) ) assert plan.is_recency_query is False + + +# ── mentioned_document_ids cross-turn drain ──────────────────────────── + + +class TestKnowledgePriorityMentionDrain: + """Regression tests for the cross-turn ``mentioned_document_ids`` drain. + + The compiled-agent cache reuses a single :class:`KnowledgePriorityMiddleware` + instance across turns of the same thread. ``mentioned_document_ids`` + can therefore enter the middleware via two paths: + + 1. The constructor closure (``__init__(mentioned_document_ids=...)``) — + seeded by the cache-miss build on turn 1. + 2. ``runtime.context.mentioned_document_ids`` — supplied freshly per + turn by the streaming task. + + Without the drain fix, an empty ``runtime.context.mentioned_document_ids`` + on turn 2 would fall through to the closure (because ``[]`` is falsy in + Python) and replay turn 1's mentions. This class pins down the + correct behaviour: the runtime path is authoritative even when empty, + and the closure is drained the first time the runtime path fires so + no later turn can ever resurrect stale state. + """ + + @staticmethod + def _make_runtime(mention_ids: list[int]): + """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``.""" + from types import SimpleNamespace + + return SimpleNamespace( + context=SimpleNamespace(mentioned_document_ids=mention_ids), + ) + + @staticmethod + def _planner_llm() -> "FakeLLM": + # Planner returns a stable, non-recency plan so we always land in + # the hybrid-search branch (where ``fetch_mentioned_documents`` is + # invoked alongside the main search). + return FakeLLM( + json.dumps( + { + "optimized_query": "follow up question", + "start_date": None, + "end_date": None, + "is_recency_query": False, + } + ) + ) + + async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch): + """Turn 1 with mentions in BOTH closure and runtime context: the + runtime path wins AND the closure is drained so a future turn + cannot replay it. 
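+
+        The drain shape this pins down (assumed; the middleware may
+        phrase it differently)::
+
+            ctx = getattr(runtime, "context", None) if runtime else None
+            mentions = getattr(ctx, "mentioned_document_ids", None)
+            if mentions is not None:             # authoritative, even []
+                self.mentioned_document_ids = []     # drain the closure
+            elif self.mentioned_document_ids:        # legacy: fire once
+                mentions = self.mentioned_document_ids
+                self.mentioned_document_ids = []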
+ """ + fetched_ids: list[list[int]] = [] + + async def fake_fetch_mentioned_documents(*, document_ids, search_space_id): + fetched_ids.append(list(document_ids)) + return [] + + async def fake_search_knowledge_base(**_kwargs): + return [] + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + fake_fetch_mentioned_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + + middleware = KnowledgeBaseSearchMiddleware( + llm=self._planner_llm(), + search_space_id=42, + mentioned_document_ids=[1, 2, 3], + ) + + await middleware.abefore_agent( + {"messages": [HumanMessage(content="what is in those docs?")]}, + runtime=self._make_runtime([1, 2, 3]), + ) + + assert fetched_ids == [[1, 2, 3]], ( + "runtime.context mentions must be the source of truth on turn 1" + ) + assert middleware.mentioned_document_ids == [], ( + "closure must be drained the first time the runtime path fires " + "so no later turn can replay stale mentions" + ) + + async def test_empty_runtime_context_does_not_replay_closure_mentions( + self, monkeypatch + ): + """Regression: turn 2 with NO mentions must not surface turn 1's + mentions from the constructor closure. + + Before the fix, ``if ctx_mentions:`` treated an empty list as + absent and fell through to ``elif self.mentioned_document_ids:``, + replaying turn 1's mentions. This test pins down the corrected + behaviour. + """ + fetched_ids: list[list[int]] = [] + + async def fake_fetch_mentioned_documents(*, document_ids, search_space_id): + fetched_ids.append(list(document_ids)) + return [] + + async def fake_search_knowledge_base(**_kwargs): + return [] + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + fake_fetch_mentioned_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + + # Simulate a cached middleware instance whose closure was seeded + # by a previous turn's cache-miss build (mentions=[1,2,3]). + middleware = KnowledgeBaseSearchMiddleware( + llm=self._planner_llm(), + search_space_id=42, + mentioned_document_ids=[1, 2, 3], + ) + + # Turn 2: streaming task supplies an EMPTY mention list (no + # mentions on this follow-up turn). + await middleware.abefore_agent( + {"messages": [HumanMessage(content="what about the next steps?")]}, + runtime=self._make_runtime([]), + ) + + assert fetched_ids == [], ( + "fetch_mentioned_documents must NOT be called when the runtime " + "context says there are no mentions for this turn" + ) + + async def test_legacy_path_fires_only_when_runtime_context_absent( + self, monkeypatch + ): + """Backward-compat: if a caller doesn't supply runtime.context (old + non-streaming code path), the closure-injected mentions are still + honoured exactly once and then drained. 
+ """ + fetched_ids: list[list[int]] = [] + + async def fake_fetch_mentioned_documents(*, document_ids, search_space_id): + fetched_ids.append(list(document_ids)) + return [] + + async def fake_search_knowledge_base(**_kwargs): + return [] + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + fake_fetch_mentioned_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + + middleware = KnowledgeBaseSearchMiddleware( + llm=self._planner_llm(), + search_space_id=42, + mentioned_document_ids=[7, 8], + ) + + # First call: no runtime → legacy path uses the closure. + await middleware.abefore_agent( + {"messages": [HumanMessage(content="initial question")]}, + runtime=None, + ) + # Second call: still no runtime — closure already drained, so no replay. + await middleware.abefore_agent( + {"messages": [HumanMessage(content="follow up")]}, + runtime=None, + ) + + assert fetched_ids == [[7, 8]], ( + "legacy path must honour the closure exactly once and then drain it" + ) + assert middleware.mentioned_document_ids == [] diff --git a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py new file mode 100644 index 000000000..c9f18d77d --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py @@ -0,0 +1,110 @@ +"""Unit tests for ``supports_image_input`` derivation on BYOK chat config +endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``). + +There is no DB column for ``supports_image_input`` on +``NewLLMConfig`` — the value is resolved at the API boundary by +``derive_supports_image_input`` so the new-chat selector / streaming +task can read the same field shape regardless of source (BYOK vs YAML +vs OpenRouter dynamic). Default-allow on unknown so we don't lock the +user out of their own model choice. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from uuid import uuid4 + +import pytest + +from app.db import LiteLLMProvider +from app.routes import new_llm_config_routes + +pytestmark = pytest.mark.unit + + +def _byok_row( + *, + id_: int, + model_name: str, + base_model: str | None = None, + provider: LiteLLMProvider = LiteLLMProvider.OPENAI, + custom_provider: str | None = None, +) -> object: + """Mimic the SQLAlchemy row's attribute surface; ``model_validate`` + walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough. + + ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's + enum validator accepts it — same as the ORM row would carry.""" + return SimpleNamespace( + id=id_, + name=f"BYOK-{id_}", + description=None, + provider=provider, + custom_provider=custom_provider, + model_name=model_name, + api_key="sk-byok", + api_base=None, + litellm_params={"base_model": base_model} if base_model else None, + system_instructions="", + use_default_system_instructions=True, + citations_enabled=True, + created_at=datetime.now(tz=UTC), + search_space_id=42, + user_id=uuid4(), + ) + + +def test_serialize_byok_known_vision_model_resolves_true(): + """The catalog resolver consults LiteLLM's map for ``gpt-4o`` -> + True. 
The serialized row carries that value through to the + ``NewLLMConfigRead`` schema.""" + row = _byok_row(id_=1, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + assert serialized.id == 1 + assert serialized.model_name == "gpt-4o" + + +def test_serialize_byok_unknown_model_default_allows(): + """Unknown / unmapped: default-allow. The streaming-task safety net + is the actual block, and it requires LiteLLM to *explicitly* say + text-only — so a brand new BYOK model should not be pre-judged.""" + row = _byok_row( + id_=2, + model_name="brand-new-model-x9-unmapped", + provider=LiteLLMProvider.CUSTOM, + custom_provider="brand_new_proxy", + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_uses_base_model_when_present(): + """Azure-style: ``model_name`` is the deployment id, ``base_model`` + inside ``litellm_params`` is the canonical sku LiteLLM knows. The + helper must consult ``base_model`` first or unrecognised deployment + ids would shadow the real capability.""" + row = _byok_row( + id_=3, + model_name="my-azure-deployment-id-no-litellm-knows-this", + base_model="gpt-4o", + provider=LiteLLMProvider.AZURE_OPENAI, + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_returns_pydantic_read_model(): + """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so + the schema additions are guaranteed to be present in the API + surface. This guards against a future regression where someone + deletes the augmentation step and falls back to ORM passthrough.""" + from app.schemas import NewLLMConfigRead + + row = _byok_row(id_=4, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + assert isinstance(serialized, NewLLMConfigRead) diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py new file mode 100644 index 000000000..2b6c76485 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py @@ -0,0 +1,184 @@ +"""Unit tests for ``is_premium`` derivation on the global image-gen and +vision-LLM list endpoints. + +Chat globals (``GET /global-llm-configs``) already emit +``is_premium = (billing_tier == "premium")``. Image and vision did not, +which made the new-chat ``model-selector`` render the Free/Premium badge +on the Chat tab but skip it on the Image and Vision tabs (the selector +keys its badge logic off ``is_premium``). 
These tests pin parity: + +* YAML free entry → ``is_premium=False`` +* YAML premium entry → ``is_premium=True`` +* OpenRouter dynamic premium entry → ``is_premium=True`` +* Auto stub (always emitted when at least one config is present) + → ``is_premium=False`` +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_IMAGE_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "DALL-E 3", + "provider": "OPENAI", + "model_name": "dall-e-3", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "GPT-Image 1 (premium)", + "provider": "OPENAI", + "model_name": "gpt-image-1", + "api_key": "sk-test", + "billing_tier": "premium", + }, + { + "id": -20_001, + "name": "google/gemini-2.5-flash-image (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +_VISION_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o Vision", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "Claude 3.5 Sonnet (premium)", + "provider": "ANTHROPIC", + "model_name": "claude-3-5-sonnet", + "api_key": "sk-ant-test", + "billing_tier": "premium", + }, + { + "id": -30_001, + "name": "openai/gpt-4o (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +# ============================================================================= +# Image generation +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_emit_is_premium(monkeypatch): + """Each emitted config must carry ``is_premium`` derived server-side + from ``billing_tier``. The Auto stub is always free. + """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False + ) + + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + # Auto stub is always emitted when at least one global config exists, + # and it must always declare itself free (Auto-mode billing-tier + # surfacing is a separate follow-up). + assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + # YAML free entry — ``is_premium=False`` + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + # YAML premium entry — ``is_premium=True`` + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + # OpenRouter dynamic premium entry — same field, same derivation + assert by_id[-20_001]["is_premium"] is True + assert by_id[-20_001]["billing_tier"] == "premium" + + # Every emitted dict (including Auto) must have the field — never missing. + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch): + """When there are no global configs at all, the endpoint emits an + empty list (no Auto stub) — Auto mode would have nothing to route to. 
+ """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False) + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + assert payload == [] + + +# ============================================================================= +# Vision LLM +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_emit_is_premium(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr( + config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False + ) + + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + assert by_id[-30_001]["is_premium"] is True + assert by_id[-30_001]["billing_tier"] == "premium" + + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False) + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + assert payload == [] diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py new file mode 100644 index 000000000..b47d9134b --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py @@ -0,0 +1,106 @@ +"""Unit tests for ``supports_image_input`` derivation on the chat global +config endpoint (``GET /global-new-llm-configs``). + +Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``): + +1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML + loader for operator overrides, or by the OpenRouter integration from + ``architecture.input_modalities``) — wins. +2. ``derive_supports_image_input`` helper — default-allow on unknown + models, only False when LiteLLM / OR modalities are definitive. + +The flag is purely informational at the API boundary. The streaming +task safety net (``is_known_text_only_chat_model``) is the actual block, +and it requires LiteLLM to *explicitly* mark the model as text-only. 
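+
+Resolution as a sketch (assumed; the exact signature of
+``derive_supports_image_input`` is not pinned here)::
+
+    explicit = cfg.get("supports_image_input")
+    if isinstance(explicit, bool):
+        flag = explicit                         # YAML / OpenRouter override wins
+    else:
+        flag = derive_supports_image_input(cfg)     # default-allow on unknown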
+""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o (explicit true)", + "description": "vision-capable, explicit YAML override", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + "supports_image_input": True, + }, + { + "id": -2, + "name": "DeepSeek V3 (explicit false)", + "description": "OpenRouter dynamic — modality-derived false", + "provider": "OPENROUTER", + "model_name": "deepseek/deepseek-v3.2-exp", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "free", + "supports_image_input": False, + }, + { + "id": -10_010, + "name": "Unannotated GPT-4o", + "description": "no flag set — resolver should derive True via LiteLLM", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + # supports_image_input intentionally absent + }, + { + "id": -10_011, + "name": "Unannotated unknown model", + "description": "unmapped — default-allow True", + "provider": "CUSTOM", + "custom_provider": "brand_new_proxy", + "model_name": "brand-new-model-x9", + "api_key": "sk-test", + "billing_tier": "free", + }, +] + + +@pytest.mark.asyncio +async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch): + """Each emitted chat config carries ``supports_image_input`` as a + bool. Explicit values win; unannotated entries are resolved via the + helper (default-allow True).""" + from app.config import config + from app.routes import new_llm_config_routes + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False) + + payload = await new_llm_config_routes.get_global_new_llm_configs(user=None) + by_id = {c["id"]: c for c in payload} + + # Auto stub: optimistic True so the user can keep Auto selected with + # vision-capable deployments somewhere in the pool. + assert 0 in by_id, "Auto stub should be emitted when configs exist" + assert by_id[0]["supports_image_input"] is True + assert by_id[0]["is_auto_mode"] is True + + # Explicit True is preserved. + assert by_id[-1]["supports_image_input"] is True + + # Explicit False is preserved (the exact failure mode the safety net + # guards against — DeepSeek V3 over OpenRouter would 404 with "No + # endpoints found that support image input"). + assert by_id[-2]["supports_image_input"] is False + + # Unannotated GPT-4o: resolver consults LiteLLM, which says vision. + assert by_id[-10_010]["supports_image_input"] is True + + # Unknown / unmapped model: default-allow rather than pre-judge. + assert by_id[-10_011]["supports_image_input"] is True + + for cfg in payload: + assert "supports_image_input" in cfg, ( + f"supports_image_input missing from {cfg.get('id')}" + ) + assert isinstance(cfg["supports_image_input"], bool) diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py new file mode 100644 index 000000000..636b7de31 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py @@ -0,0 +1,138 @@ +"""Unit tests for the image-generation route's billing-resolution helper. + +End-to-end "POST /image-generations returns 402" coverage requires the +integration harness (real DB, real auth) and lives in +``tests/integration/document_upload/`` alongside the other quota tests. 
+This unit test focuses on the new ``_resolve_billing_for_image_gen`` +helper which: + +* Returns ``free`` for Auto mode, even when premium configs exist + (Auto-mode billing-tier surfacing is a follow-up). +* Returns ``free`` for user-owned BYOK configs (positive IDs). +* Returns the global config's ``billing_tier`` for negative IDs. +* Honours the per-config ``quota_reserve_micros`` override when present. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_resolve_billing_for_auto_mode(monkeypatch): + from app.routes import image_generation_routes + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + search_space = SimpleNamespace(image_generation_config_id=None) + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, # Not consumed on this code path. + config_id=0, # IMAGE_GEN_AUTO_MODE_ID + search_space=search_space, + ) + assert tier == "free" + assert model == "auto" + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_for_premium_global_config(monkeypatch): + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-image-1", + "billing_tier": "premium", + "quota_reserve_micros": 75_000, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", + "billing_tier": "free", + }, + ], + raising=False, + ) + + search_space = SimpleNamespace(image_generation_config_id=None) + + # Premium with override. + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=-1, search_space=search_space + ) + assert tier == "premium" + assert model == "openai/gpt-image-1" + assert reserve == 75_000 + + # Free, no override → falls back to default. + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=-2, search_space=search_space + ) + assert tier == "free" + # Provider-prefixed model string for OpenRouter. + assert "google/gemini-2.5-flash-image" in model + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_for_user_owned_byok_is_free(): + """User-owned BYOK configs (positive IDs) cost the user nothing on + our side — they pay the provider directly. Always free. + """ + from app.routes import image_generation_routes + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + search_space = SimpleNamespace(image_generation_config_id=None) + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=42, search_space=search_space + ) + assert tier == "free" + assert model == "user_byok" + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): + """When the request omits ``image_generation_config_id``, the helper + must consult the search space's default — so a search space pinned + to a premium global config still gates new requests by quota. 
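+
+    Expected fallback, sketched (assumed wording, not the helper's code):
+
+        if config_id is None:
+            config_id = search_space.image_generation_config_id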
+ """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + { + "id": -7, + "provider": "OPENAI", + "model_name": "gpt-image-1", + "billing_tier": "premium", + } + ], + raising=False, + ) + + search_space = SimpleNamespace(image_generation_config_id=-7) + ( + tier, + model, + _reserve, + ) = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=None, search_space=search_space + ) + assert tier == "premium" + assert model == "openai/gpt-image-1" diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py new file mode 100644 index 000000000..fa8819b39 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py @@ -0,0 +1,436 @@ +"""Unit tests for ``_resolve_agent_billing_for_search_space``. + +Validates the resolver used by Celery podcast/video tasks to compute +``(owner_user_id, billing_tier, base_model)`` from a search space and its +agent LLM config. The resolver mirrors chat's billing-resolution pattern at +``stream_new_chat.py:2294-2351`` and is the single integration point that +prevents Auto-mode podcast/video from leaking premium credit. + +Coverage: + +* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium + global → returns ``("premium", )``. +* Auto mode + ``thread_id`` set, pin resolves to a negative-id free + global → returns ``("free", )``. +* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config + → always ``"free"``. +* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without + hitting the pin service. +* Negative id (no Auto) → uses ``get_global_llm_config``'s + ``billing_tier``. +* Positive id (user BYOK) → always ``"free"``. +* Search space not found → raises ``ValueError``. +* ``agent_llm_id`` is None → raises ``ValueError``. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace +from uuid import UUID, uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + +class _FakeSession: + """Tiny AsyncSession stub. + + ``responses`` is a list of objects to return from successive + ``execute()`` calls (in order). The resolver makes at most two + ``execute()`` calls (search-space lookup, then optionally NewLLMConfig + lookup), so two queued responses cover the matrix. 
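+
+    Example (search-space row first, then the NewLLMConfig row):
+
+        session = _FakeSession([search_space, byok_cfg])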
+ """ + + def __init__(self, responses: list): + self._responses = list(responses) + + async def execute(self, _stmt): + if not self._responses: + return _FakeExecResult(None) + return _FakeExecResult(self._responses.pop(0)) + + async def commit(self) -> None: + pass + + +@dataclass +class _FakePinResolution: + resolved_llm_config_id: int + resolved_tier: str = "premium" + from_existing_pin: bool = False + + +def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace: + return SimpleNamespace( + id=42, + agent_llm_id=agent_llm_id, + user_id=user_id, + ) + + +def _make_byok_config( + *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok" +) -> SimpleNamespace: + return SimpleNamespace( + id=id_, + model_name=model_name, + litellm_params={"base_model": base_model} if base_model else {}, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): + """Auto + thread → pin service resolves to negative-id premium config → + resolver returns ``("premium", )``.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + # Mock the pin service to return a concrete premium config id. + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + assert selected_llm_config_id == 0 + assert thread_id == 99 + return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium") + + # Mock global config lookup to return a premium entry. + def _fake_get_global(cfg_id): + if cfg_id == -1: + return { + "id": -1, + "model_name": "gpt-5.4", + "billing_tier": "premium", + "litellm_params": {"base_model": "gpt-5.4"}, + } + return None + + # Lazy imports inside the resolver — patch the *target* modules so the + # imported names resolve to our fakes. + import app.services.auto_model_pin_service as pin_module + import app.services.llm_service as llm_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "premium" + assert base_model == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch): + """Auto + thread → pin returns negative-id free config → resolver + returns ``("free", )``. 
Same path the pin service takes for + out-of-credit users (graceful degradation).""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free") + + def _fake_get_global(cfg_id): + if cfg_id == -3: + return { + "id": -3, + "model_name": "openrouter/free-model", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/free-model"}, + } + return None + + import app.services.auto_model_pin_service as pin_module + import app.services.llm_service as llm_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/free-model" + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): + """Auto + thread → pin returns positive-id BYOK config → resolver + returns ``("free", ...)`` (BYOK is always free per + ``AgentConfig.from_new_llm_config``).""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + search_space = _make_search_space(agent_llm_id=0, user_id=user_id) + byok_cfg = _make_byok_config( + id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude" + ) + session = _FakeSession([search_space, byok_cfg]) + + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free") + + import app.services.auto_model_pin_service as pin_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "anthropic/claude-3-haiku" + + +@pytest.mark.asyncio +async def test_auto_mode_without_thread_id_falls_back_to_free(): + """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking + the pin service. 
Forward-compat fallback for any future direct-API + entrypoint that doesn't have a chat thread.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=None + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "auto" + + +@pytest.mark.asyncio +async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): + """If the pin service raises ``ValueError`` (thread missing / + mismatched search space), the resolver should log and return free + rather than killing the whole task.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + async def _fake_resolve_pin(*args, **kwargs): + raise ValueError("thread missing") + + import app.services.auto_model_pin_service as pin_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "auto" + + +@pytest.mark.asyncio +async def test_negative_id_premium_global_returns_premium(monkeypatch): + """Explicit negative agent_llm_id → ``get_global_llm_config`` → + return its ``billing_tier``.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "gpt-5.4", + "billing_tier": "premium", + "litellm_params": {"base_model": "gpt-5.4"}, + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "premium" + assert base_model == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_negative_id_free_global_returns_free(monkeypatch): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "openrouter/some-free", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/some-free"}, + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=None + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/some-free" + + +@pytest.mark.asyncio +async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch): + """When the global config has no ``litellm_params.base_model``, the + resolver falls back to ``model_name`` — matching chat's behavior.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-5, 
user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "fallback-model", + "billing_tier": "premium", + # No litellm_params. + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + _, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert tier == "premium" + assert base_model == "fallback-model" + + +@pytest.mark.asyncio +async def test_positive_id_byok_is_always_free(): + """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free, + regardless of underlying provider tier.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + search_space = _make_search_space(agent_llm_id=23, user_id=user_id) + byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet") + session = _FakeSession([search_space, byok_cfg]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "anthropic/claude-3.5-sonnet" + + +@pytest.mark.asyncio +async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): + """If the BYOK config row is missing/deleted but the search space still + points at it, the resolver still returns free (no debit) with an empty + base_model — billable_call's premium path is skipped, no harm done.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "" + + +@pytest.mark.asyncio +async def test_search_space_not_found_raises_value_error(): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + session = _FakeSession([None]) + + with pytest.raises(ValueError, match="Search space"): + await _resolve_agent_billing_for_search_space(session, search_space_id=999) + + +@pytest.mark.asyncio +async def test_agent_llm_id_none_raises_value_error(): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)]) + + with pytest.raises(ValueError, match="agent_llm_id"): + await _resolve_agent_billing_for_search_space(session, search_space_id=42) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py new file mode 100644 index 000000000..d1af29aeb --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -0,0 +1,1026 @@ +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from app.services.auto_model_pin_service import ( + clear_healthy, + clear_runtime_cooldown, + is_recently_healthy, + mark_healthy, + mark_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _clear_runtime_cooldown_map(): + clear_runtime_cooldown() + clear_healthy() + yield + clear_runtime_cooldown() + clear_healthy() + + +@dataclass +class _FakeQuotaResult: + allowed: bool + + +class 
_FakeExecResult: + def __init__(self, thread): + self._thread = thread + + def unique(self): + return self + + def scalar_one_or_none(self): + return self._thread + + +class _FakeSession: + def __init__(self, thread): + self.thread = thread + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self.thread) + + async def commit(self): + self.commit_count += 1 + + +def _thread( + *, + search_space_id: int = 10, + pinned_llm_config_id: int | None = None, +): + return SimpleNamespace( + id=1, + search_space_id=search_space_id, + pinned_llm_config_id=pinned_llm_config_id, + ) + + +@pytest.mark.asyncio +async def test_auto_first_turn_pins_one_model(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id + assert session.commit_count == 1 + + +@pytest.mark.asyncio +async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + "billing_tier": "free", + "quality_score": 100, + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + "quality_score": 10, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.resolved_tier == "premium" + + +@pytest.mark.asyncio +async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5.1", + "api_key": "k1", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 100, + }, + { + "id": -2, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5.4", + "api_key": "k2", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 10, + }, + { + "id": -3, + "provider": "OPENROUTER", + "model_name": "openai/gpt-5.4", + "api_key": "k3", + "billing_tier": "premium", + "auto_pin_tier": "B", + "quality_score": 100, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + 
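+    # Pool: two Tier A Azure deployments (gpt-5.1 scored 100, gpt-5.4
+    # scored 10) plus a Tier B OpenRouter mirror of gpt-5.4. The
+    # expectation below pins the Tier A gpt-5.4 deployment despite its
+    # lower quality_score and never spills to Tier B while Tier A has
+    # eligible candidates.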
+ result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.resolved_tier == "premium" + + +@pytest.mark.asyncio +async def test_next_turn_reuses_existing_pin(monkeypatch): + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError( + "premium_get_usage should not be called for valid pin reuse" + ) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + assert session.commit_count == 0 + + +@pytest.mark.asyncio +async def test_premium_eligible_auto_can_pin_premium(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.resolved_tier == "premium" + + +@pytest.mark.asyncio +async def test_premium_ineligible_auto_pins_free_only(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + "billing_tier": "free", + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.resolved_tier == "free" + + +@pytest.mark.asyncio +async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + "billing_tier": "free", + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return 
_FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + "billing_tier": "free", + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + force_repin_free=True, + ) + assert result.resolved_llm_config_id == -2 + assert result.resolved_tier == "free" + assert result.from_existing_pin is False + assert session.thread.pinned_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_explicit_user_model_change_clears_pin(monkeypatch): + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-2)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + ], + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=7, + ) + assert result.resolved_llm_config_id == 7 + assert session.thread.pinned_llm_config_id is None + assert session.commit_count == 1 + + +@pytest.mark.asyncio +async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-999)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert session.thread.pinned_llm_config_id == -2 + assert session.commit_count == 1 + + +# --------------------------------------------------------------------------- +# Quality-aware pin selection (Auto Fastest upgrade) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_health_gated_config_is_excluded_from_selection(monkeypatch): + """A cfg flagged ``health_gated`` must never be picked even if it has + the highest score among eligible cfgs.""" + from app.config import config + + session = _FakeSession(_thread()) + 
monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "venice/dead-model", + "api_key": "k1", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 95, + "health_gated": True, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-flash", + "api_key": "k1", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 60, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): + """Premium-eligible users with Tier A available should never spill to + Tier B even if a B cfg ranks higher by ``quality_score``.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k-yaml", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 70, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "openai/gpt-5", + "api_key": "k-or", + "billing_tier": "premium", + "auto_pin_tier": "B", + "quality_score": 95, + "health_gated": False, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.resolved_tier == "premium" + + +@pytest.mark.asyncio +async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch): + """Free-only user with no Tier A free cfg should pick from Tier C.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k-yaml", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 100, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-flash:free", + "api_key": "k-or", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 60, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_top_k_picks_only_high_score_models(monkeypatch): + """Different thread IDs should spread across top-K, never pick the + obvious low-quality cfg even when it sits in the candidate 
list.""" + from app.config import config + + high_score_cfgs = [ + { + "id": -i, + "provider": "AZURE_OPENAI", + "model_name": f"gpt-x-{i}", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 90, + "health_gated": False, + } + for i in range(1, 6) # 5 high-quality Tier A cfgs + ] + low_score_trap = { + "id": -99, + "provider": "AZURE_OPENAI", + "model_name": "tiny-legacy", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 10, + "health_gated": False, + } + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [*high_score_cfgs, low_score_trap], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + high_score_ids = {c["id"] for c in high_score_cfgs} + seen = set() + for thread_id in range(1, 50): + session = _FakeSession(_thread()) + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=thread_id, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + seen.add(result.resolved_llm_config_id) + assert result.resolved_llm_config_id != -99, ( + "low-score trap cfg should never be picked" + ) + assert result.resolved_llm_config_id in high_score_ids + + # Spread across at least a couple of top-K cfgs. + assert len(seen) > 1 + + +@pytest.mark.asyncio +async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): + """An *already* pinned cfg that later flips to ``health_gated`` should + still not be reused — gated cfgs are filtered out of the candidate + pool, which forces a repair to a healthy cfg. + + This guards the no-silent-tier-switch invariant: we don't keep using + a known-broken model just because the thread happened to be pinned + to it before the gate fired.""" + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "venice/dead-model", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "B", + "quality_score": 50, + "health_gated": True, + }, + { + "id": -2, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 90, + "health_gated": False, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + +@pytest.mark.asyncio +async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): + """Existing pin reuse must short-circuit the new tier/score logic.""" + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 50, # lower than -2 + "health_gated": False, + }, + { + "id": -2, + "provider": "AZURE_OPENAI", + "model_name": 
"gpt-5-pro", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 99, + "health_gated": False, + }, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not run on pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + assert session.commit_count == 0 + + +@pytest.mark.asyncio +async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): + """A runtime-cooled config should be excluded from candidate reuse. + + This enables one-shot recovery from transient provider 429 bursts: we can + mark the pinned cfg as cooled down and force a repair to another eligible + cfg on the next resolution. + """ + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 80, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + +@pytest.mark.asyncio +async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not run on healthy pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600) + clear_runtime_cooldown(-1) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch): + """Runtime retry should never repin the just-failed config.""" + from app.config import config + + session = 
_FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 80, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + exclude_config_ids={-1}, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + +# --------------------------------------------------------------------------- +# Healthy-status cache (preflight TTL companion) +# --------------------------------------------------------------------------- + + +def test_mark_healthy_then_is_recently_healthy_true_within_ttl(): + mark_healthy(-42, ttl_seconds=60) + assert is_recently_healthy(-42) is True + + +def test_healthy_expires_after_ttl(monkeypatch): + import app.services.auto_model_pin_service as svc + + real_time = svc.time.time + base = real_time() + + monkeypatch.setattr(svc.time, "time", lambda: base) + mark_healthy(-7, ttl_seconds=10) + assert is_recently_healthy(-7) is True + + monkeypatch.setattr(svc.time, "time", lambda: base + 11) + assert is_recently_healthy(-7) is False + + +def test_mark_runtime_cooldown_invalidates_healthy_cache(): + mark_healthy(-9, ttl_seconds=60) + assert is_recently_healthy(-9) is True + + mark_runtime_cooldown(-9, reason="test", cooldown_seconds=60) + assert is_recently_healthy(-9) is False + + +def test_clear_healthy_removes_single_entry(): + mark_healthy(-11, ttl_seconds=60) + mark_healthy(-12, ttl_seconds=60) + clear_healthy(-11) + assert is_recently_healthy(-11) is False + assert is_recently_healthy(-12) is True + + +def test_clear_healthy_no_args_drops_all_entries(): + mark_healthy(-21, ttl_seconds=60) + mark_healthy(-22, ttl_seconds=60) + clear_healthy() + assert is_recently_healthy(-21) is False + assert is_recently_healthy(-22) is False diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py new file mode 100644 index 000000000..0e19b80e4 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py @@ -0,0 +1,286 @@ +"""Image-aware extension of the Auto-pin resolver. + +When the current chat turn carries an ``image_url`` block, the pin +resolver must: + +1. Filter the candidate pool to vision-capable cfgs so a freshly + selected pin can never be text-only. +2. Treat any existing pin whose capability is False as invalid (force + re-pin), even when it would otherwise be reused as the thread's + stable model. +3. Raise ``ValueError`` (mapped to the friendly + ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming + task) when no vision-capable cfg is available — instead of silently + pinning text-only and 404-ing at the provider. 
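+
+Filtering sketch (assumed shape, not the resolver's literal code;
+``_supports_image`` is a hypothetical predicate):
+
+    if requires_image_input:
+        candidates = [c for c in candidates if _supports_image(c)]
+        if not candidates:
+            raise ValueError("no vision-capable global config available")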
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from app.services.auto_model_pin_service import ( + clear_healthy, + clear_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _reset_caches(): + clear_runtime_cooldown() + clear_healthy() + yield + clear_runtime_cooldown() + clear_healthy() + + +@dataclass +class _FakeQuotaResult: + allowed: bool + + +class _FakeExecResult: + def __init__(self, thread): + self._thread = thread + + def unique(self): + return self + + def scalar_one_or_none(self): + return self._thread + + +class _FakeSession: + def __init__(self, thread): + self.thread = thread + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self.thread) + + async def commit(self): + self.commit_count += 1 + + +def _thread(*, pinned: int | None = None): + return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned) + + +def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: + return { + "id": id_, + "provider": "OPENAI", + "model_name": f"vision-{id_}", + "api_key": "k", + "billing_tier": tier, + "supports_image_input": True, + "auto_pin_tier": "A", + "quality_score": quality, + } + + +def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict: + return { + "id": id_, + "provider": "OPENAI", + "model_name": f"text-{id_}", + "api_key": "k", + "billing_tier": tier, + # Higher quality than the vision cfgs — so a bug that ignores + # the image flag would surface as the resolver picking this one. + "supports_image_input": False, + "auto_pin_tier": "A", + "quality_score": quality, + } + + +async def _premium_allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + +@pytest.mark.asyncio +async def test_image_turn_filters_out_text_only_candidates(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + # The thread should be pinned to the vision cfg even though the + # text-only cfg has a higher quality score. + assert session.thread.pinned_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch): + """An existing text-only pin must be invalidated when the next turn + requires image input. 
The non-image path would happily reuse it.""" + from app.config import config + + session = _FakeSession(_thread(pinned=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + assert session.thread.pinned_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_image_turn_reuses_existing_vision_pin(monkeypatch): + """If the thread is already pinned to a vision-capable cfg, reuse it + — same as the non-image path. Image-aware filtering must not force + spurious re-pins.""" + from app.config import config + + session = _FakeSession(_thread(pinned=-2)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_image_turn_with_no_vision_candidates_raises(monkeypatch): + """The friendly-error path: no vision-capable cfg in the pool -> raise + ``ValueError`` whose message contains ``vision-capable`` so the + streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _text_only_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + with pytest.raises(ValueError, match="vision-capable"): + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + +@pytest.mark.asyncio +async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch): + """Regression guard: the image flag must default False and not affect + a normal text-only turn — text-only cfgs remain selectable.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + + +@pytest.mark.asyncio +async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): + """A YAML cfg that omits ``supports_image_input`` falls through to + ``derive_supports_image_input`` (LiteLLM-driven). 
For ``gpt-4o`` + that returns True, so the cfg should be a valid candidate.""" + from app.config import config + + session = _FakeSession(_thread()) + cfg_unannotated_vision = { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-4o", # known vision model in LiteLLM map + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "A", + "quality_score": 80, + # NOTE: no supports_image_input key + } + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision]) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + assert result.resolved_llm_config_id == -2 diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py new file mode 100644 index 000000000..c820724ed --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_billable_call.py @@ -0,0 +1,559 @@ +"""Unit tests for the ``billable_call`` async context manager. + +Covers the per-call premium-credit lifecycle for image generation and +vision LLM extraction: + +* Free configs bypass reserve/finalize but still write an audit row. +* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the + route layer). +* Successful premium calls reserve, yield the accumulator, then finalize + with the LiteLLM-reported actual cost — and write an audit row. +* Failed premium calls release the reservation so credit isn't leaked. +* All quota DB ops happen inside their OWN ``shielded_async_session``, + isolating them from the caller's transaction (issue A). 
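+
+Canonical call shape exercised below (arguments elided):
+
+    async with billable_call(user_id=..., billing_tier="premium",
+                             base_model=..., usage_type=...) as acc:
+        acc.add(model=..., cost_micros=..., call_kind=...)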
+""" + +from __future__ import annotations + +import asyncio +import contextlib +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeQuotaResult: + def __init__( + self, + *, + allowed: bool, + used: int = 0, + limit: int = 5_000_000, + remaining: int = 5_000_000, + ) -> None: + self.allowed = allowed + self.used = used + self.limit = limit + self.remaining = remaining + + +class _FakeSession: + """Minimal AsyncSession stub — record commits for assertion.""" + + def __init__(self) -> None: + self.committed = False + self.added: list[Any] = [] + + def add(self, obj: Any) -> None: + self.added.append(obj) + + async def commit(self) -> None: + self.committed = True + + async def rollback(self) -> None: + pass + + async def close(self) -> None: + pass + + +@contextlib.asynccontextmanager +async def _fake_shielded_session(): + s = _FakeSession() + _SESSIONS_USED.append(s) + yield s + + +_SESSIONS_USED: list[_FakeSession] = [] + + +def _patch_isolation_layer( + monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None +): + """Wire fake reserve/finalize/release/session helpers.""" + _SESSIONS_USED.clear() + reserve_calls: list[dict[str, Any]] = [] + finalize_calls: list[dict[str, Any]] = [] + release_calls: list[dict[str, Any]] = [] + + async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros): + reserve_calls.append( + { + "user_id": user_id, + "reserve_micros": reserve_micros, + "request_id": request_id, + } + ) + return reserve_result + + async def _fake_finalize( + *, db_session, user_id, request_id, actual_micros, reserved_micros + ): + if finalize_exc is not None: + raise finalize_exc + finalize_calls.append( + { + "user_id": user_id, + "actual_micros": actual_micros, + "reserved_micros": reserved_micros, + } + ) + return finalize_result or _FakeQuotaResult(allowed=True) + + async def _fake_release(*, db_session, user_id, reserved_micros): + release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros}) + + record_calls: list[dict[str, Any]] = [] + + async def _fake_record(session, **kwargs): + record_calls.append(kwargs) + return object() + + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_reserve", + _fake_reserve, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_finalize", + _fake_finalize, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_release", + _fake_release, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.shielded_async_session", + _fake_shielded_session, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.record_token_usage", + _fake_record, + raising=False, + ) + + return { + "reserve": reserve_calls, + "finalize": finalize_calls, + "release": release_calls, + "record": record_calls, + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = 
uuid4()
+
+    async with billable_call(
+        user_id=user_id,
+        search_space_id=42,
+        billing_tier="free",
+        base_model="openai/gpt-image-1",
+        usage_type="image_generation",
+    ) as acc:
+        # Simulate a captured cost — the accumulator is fed by the LiteLLM
+        # callback in real life; here we add it manually.
+        acc.add(
+            model="openai/gpt-image-1",
+            prompt_tokens=0,
+            completion_tokens=0,
+            total_tokens=0,
+            cost_micros=37_000,
+            call_kind="image_generation",
+        )
+
+    assert spies["reserve"] == []
+    assert spies["finalize"] == []
+    assert spies["release"] == []
+    # Free still audits.
+    assert len(spies["record"]) == 1
+    assert spies["record"][0]["usage_type"] == "image_generation"
+    assert spies["record"][0]["cost_micros"] == 37_000
+
+
+@pytest.mark.asyncio
+async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
+    from app.services.billable_calls import (
+        QuotaInsufficientError,
+        billable_call,
+    )
+
+    spies = _patch_isolation_layer(
+        monkeypatch,
+        reserve_result=_FakeQuotaResult(
+            allowed=False, used=5_000_000, limit=5_000_000, remaining=0
+        ),
+    )
+    user_id = uuid4()
+
+    with pytest.raises(QuotaInsufficientError) as exc_info:
+        async with billable_call(
+            user_id=user_id,
+            search_space_id=42,
+            billing_tier="premium",
+            base_model="openai/gpt-image-1",
+            quota_reserve_micros_override=50_000,
+            usage_type="image_generation",
+        ):
+            pytest.fail("body should not run when reserve is denied")
+
+    err = exc_info.value
+    assert err.usage_type == "image_generation"
+    assert err.used_micros == 5_000_000
+    assert err.limit_micros == 5_000_000
+    assert err.remaining_micros == 0
+    # Reserve was attempted, but no finalize/release on a denied reserve
+    # — the reservation never actually held credit.
+    assert len(spies["reserve"]) == 1
+    assert spies["finalize"] == []
+    assert spies["release"] == []
+    # Denied premium calls do NOT create an audit row (no work happened).
+    assert spies["record"] == []
+
+
+@pytest.mark.asyncio
+async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
+    from app.services.billable_calls import billable_call
+
+    spies = _patch_isolation_layer(
+        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+    )
+    user_id = uuid4()
+
+    async with billable_call(
+        user_id=user_id,
+        search_space_id=42,
+        billing_tier="premium",
+        base_model="openai/gpt-image-1",
+        quota_reserve_micros_override=50_000,
+        usage_type="image_generation",
+    ) as acc:
+        # LiteLLM callback would normally fill this — simulate $0.04 image.
+        acc.add(
+            model="openai/gpt-image-1",
+            prompt_tokens=0,
+            completion_tokens=0,
+            total_tokens=0,
+            cost_micros=40_000,
+            call_kind="image_generation",
+        )
+
+    assert len(spies["reserve"]) == 1
+    assert spies["reserve"][0]["reserve_micros"] == 50_000
+    assert len(spies["finalize"]) == 1
+    assert spies["finalize"][0]["actual_micros"] == 40_000
+    assert spies["finalize"][0]["reserved_micros"] == 50_000
+    assert spies["release"] == []
+    # And audit row written with the actual debited cost.
+    assert spies["record"][0]["cost_micros"] == 40_000
+    # Each quota op opened its OWN session — proves session isolation.
+    assert len(_SESSIONS_USED) >= 3
+    # reserve/finalize go through TokenQuotaService.*, which we stub, so
+    # they never commit on our fake sessions; the audit write does. At
+    # least one session must therefore have committed.
+    assert any(s.committed for s in _SESSIONS_USED)
+
+
+@pytest.mark.asyncio
+async def test_premium_failure_releases_reservation(monkeypatch):
+    from app.services.billable_calls import billable_call
+
+    spies = _patch_isolation_layer(
+        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+    )
+    user_id = uuid4()
+
+    class _ProviderError(Exception):
+        pass
+
+    with pytest.raises(_ProviderError):
+        async with billable_call(
+            user_id=user_id,
+            search_space_id=42,
+            billing_tier="premium",
+            base_model="openai/gpt-image-1",
+            quota_reserve_micros_override=50_000,
+            usage_type="image_generation",
+        ):
+            raise _ProviderError("OpenRouter 503")
+
+    assert len(spies["reserve"]) == 1
+    assert spies["finalize"] == []
+    # Failure path: release the held reservation.
+    assert len(spies["release"]) == 1
+    assert spies["release"][0]["reserved_micros"] == 50_000
+
+
+@pytest.mark.asyncio
+async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
+    """When ``quota_reserve_micros_override`` is None we fall back to
+    ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.
+    Vision LLM calls take this path (token-priced models).
+    """
+    from app.services.billable_calls import billable_call
+
+    spies = _patch_isolation_layer(
+        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+    )
+
+    captured_estimator_calls: list[dict[str, Any]] = []
+
+    def _fake_estimate(*, base_model, quota_reserve_tokens):
+        captured_estimator_calls.append(
+            {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
+        )
+        return 12_345
+
+    monkeypatch.setattr(
+        "app.services.billable_calls.estimate_call_reserve_micros",
+        _fake_estimate,
+        raising=False,
+    )
+
+    user_id = uuid4()
+    async with billable_call(
+        user_id=user_id,
+        search_space_id=1,
+        billing_tier="premium",
+        base_model="openai/gpt-4o",
+        quota_reserve_tokens=4000,
+        usage_type="vision_extraction",
+    ):
+        pass
+
+    assert captured_estimator_calls == [
+        {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
+    ]
+    assert spies["reserve"][0]["reserve_micros"] == 12_345
+
+
+@pytest.mark.asyncio
+async def test_premium_finalize_failure_propagates_and_releases(monkeypatch):
+    from app.services.billable_calls import BillingSettlementError, billable_call
+
+    class _FinalizeError(RuntimeError):
+        pass
+
+    spies = _patch_isolation_layer(
+        monkeypatch,
+        reserve_result=_FakeQuotaResult(allowed=True),
+        finalize_exc=_FinalizeError("db finalize failed"),
+    )
+    user_id = uuid4()
+
+    with pytest.raises(BillingSettlementError):
+        async with billable_call(
+            user_id=user_id,
+            search_space_id=42,
+            billing_tier="premium",
+            base_model="openai/gpt-image-1",
+            quota_reserve_micros_override=50_000,
+            usage_type="image_generation",
+        ) as acc:
+            acc.add(
+                model="openai/gpt-image-1",
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+                cost_micros=40_000,
+                call_kind="image_generation",
+            )
+
+    assert len(spies["reserve"]) == 1
+    assert len(spies["release"]) == 1
+    assert spies["record"] == []
+
+
+@pytest.mark.asyncio
+async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch):
+    from app.services.billable_calls import billable_call
+
+    spies = _patch_isolation_layer(
+        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+    )
+    user_id = uuid4()
+
+    class _HangingCommitSession(_FakeSession):
+        async def commit(self) -> None:
+            await asyncio.sleep(60)
+
+    @contextlib.asynccontextmanager
+    async def _hanging_session_factory():
+        s = _HangingCommitSession()
+        _SESSIONS_USED.append(s)
+ yield s + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + billable_session_factory=_hanging_session_factory, + audit_timeout_seconds=0.01, + ) as acc: + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=40_000, + call_kind="image_generation", + ) + + assert len(spies["reserve"]) == 1 + assert len(spies["finalize"]) == 1 + assert len(spies["record"]) == 1 + assert spies["release"] == [] + + +@pytest.mark.asyncio +async def test_free_audit_failure_is_best_effort(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + + async def _failing_record(_session, **_kwargs): + raise RuntimeError("audit insert failed") + + monkeypatch.setattr( + "app.services.billable_calls.record_token_usage", + _failing_record, + raising=False, + ) + + async with billable_call( + user_id=uuid4(), + search_space_id=42, + billing_tier="free", + base_model="openai/gpt-image-1", + usage_type="image_generation", + audit_timeout_seconds=0.01, + ) as acc: + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=37_000, + call_kind="image_generation", + ) + + assert spies["reserve"] == [] + assert spies["finalize"] == [] + + +# --------------------------------------------------------------------------- +# Podcast / video-presentation usage_type coverage +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch): + """Free podcast configs must skip reserve/finalize but still emit a + ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we + have full audit coverage of free-tier agent runs.""" + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="free", + base_model="openrouter/some-free-model", + quota_reserve_micros_override=200_000, + usage_type="podcast_generation", + thread_id=99, + call_details={"podcast_id": 7, "title": "Test Podcast"}, + ) as acc: + # Two transcript LLM calls aggregated into one accumulator. 
+ acc.add( + model="openrouter/some-free-model", + prompt_tokens=1500, + completion_tokens=8000, + total_tokens=9500, + cost_micros=0, + call_kind="chat", + ) + + assert spies["reserve"] == [] + assert spies["finalize"] == [] + assert spies["release"] == [] + + assert len(spies["record"]) == 1 + row = spies["record"][0] + assert row["usage_type"] == "podcast_generation" + assert row["thread_id"] is None + assert row["search_space_id"] == 42 + assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"} + + +@pytest.mark.asyncio +async def test_premium_video_denial_raises_quota_insufficient(monkeypatch): + """Premium video-presentation runs that hit a denied reservation must + raise ``QuotaInsufficientError`` *before* the graph runs and must not + emit an audit row (no work happened).""" + from app.services.billable_calls import ( + QuotaInsufficientError, + billable_call, + ) + + spies = _patch_isolation_layer( + monkeypatch, + reserve_result=_FakeQuotaResult( + allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000 + ), + ) + user_id = uuid4() + + with pytest.raises(QuotaInsufficientError) as exc_info: + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="gpt-5.4", + quota_reserve_micros_override=1_000_000, + usage_type="video_presentation_generation", + thread_id=99, + call_details={"video_presentation_id": 12, "title": "Test Video"}, + ): + pytest.fail("body should not run when reserve is denied") + + err = exc_info.value + assert err.usage_type == "video_presentation_generation" + assert err.remaining_micros == 500_000 + assert spies["reserve"][0]["reserve_micros"] == 1_000_000 + assert spies["finalize"] == [] + assert spies["release"] == [] + assert spies["record"] == [] diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py new file mode 100644 index 000000000..9d5fdb190 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -0,0 +1,177 @@ +"""Defense-in-depth: image-gen call sites must not let an empty +``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``. + +The bug repro: an OpenRouter image-gen config ships +``api_base=""``. The pre-fix call site in +``image_generation_routes._execute_image_generation`` did +``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which +silently dropped the empty string. LiteLLM then fell back to +``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``) +and OpenRouter's ``image_generation/transformation`` appended +``/chat/completions`` to it → 404 ``Resource not found``. + +This test pins the post-fix behaviour: with an empty ``api_base`` in +the config, the call site MUST set ``api_base`` to OpenRouter's public +URL instead of leaving it unset. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): + """The global-config branch (``config_id < 0``) of + ``_execute_image_generation`` must apply the resolver and pin + ``api_base`` to OpenRouter when the config ships an empty string. 
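+
+    The post-fix call-site shape this pins looks roughly like (sketch;
+    ``resolve_api_base`` is the shared cascade exercised in
+    ``test_provider_api_base.py``)::
+
+        api_base = resolve_api_base(
+            provider=cfg["provider"],             # "OPENROUTER"
+            provider_prefix="openrouter",
+            config_api_base=cfg.get("api_base"),  # "" → cascade steps in
+        )
+        if api_base:
+            kwargs["api_base"] = api_base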
+ """ + from app.routes import image_generation_routes + + cfg = { + "id": -20_001, + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", + "api_key": "sk-or-test", + "api_base": "", # the original bug shape + "api_version": None, + "litellm_params": {}, + } + + captured: dict = {} + + async def fake_aimage_generation(**kwargs): + captured.update(kwargs) + return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={}) + + image_gen = MagicMock() + image_gen.image_generation_config_id = cfg["id"] + image_gen.prompt = "test" + image_gen.n = 1 + image_gen.quality = None + image_gen.size = None + image_gen.style = None + image_gen.response_format = None + image_gen.model = None + + search_space = MagicMock() + search_space.image_generation_config_id = cfg["id"] + session = MagicMock() + + with ( + patch.object( + image_generation_routes, + "_get_global_image_gen_config", + return_value=cfg, + ), + patch.object( + image_generation_routes, + "aimage_generation", + side_effect=fake_aimage_generation, + ), + ): + await image_generation_routes._execute_image_generation( + session=session, image_gen=image_gen, search_space=search_space + ) + + # The whole point of the fix: even with empty ``api_base`` in the + # config, we forward OpenRouter's public URL so the call doesn't + # inherit an Azure endpoint. + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-image-1" + + +@pytest.mark.asyncio +async def test_generate_image_tool_global_sets_api_base_when_config_empty(): + """Same defense at the agent tool entry point — both surfaces share + the same OpenRouter config payloads.""" + from app.agents.new_chat.tools import generate_image as gi_module + + cfg = { + "id": -20_001, + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", + "api_key": "sk-or-test", + "api_base": "", + "api_version": None, + "litellm_params": {}, + } + + captured: dict = {} + + async def fake_aimage_generation(**kwargs): + captured.update(kwargs) + response = MagicMock() + response.model_dump.return_value = { + "data": [{"url": "https://example.com/x.png"}] + } + response._hidden_params = {"model": "openrouter/openai/gpt-image-1"} + return response + + search_space = MagicMock() + search_space.id = 1 + search_space.image_generation_config_id = cfg["id"] + + session_cm = AsyncMock() + session = AsyncMock() + session_cm.__aenter__.return_value = session + + scalars = MagicMock() + scalars.first.return_value = search_space + exec_result = MagicMock() + exec_result.scalars.return_value = scalars + session.execute.return_value = exec_result + session.add = MagicMock() + session.commit = AsyncMock() + session.refresh = AsyncMock() + + # ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback. 
+ async def _refresh(obj): + obj.id = 1 + + session.refresh.side_effect = _refresh + + with ( + patch.object(gi_module, "shielded_async_session", return_value=session_cm), + patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg), + patch.object( + gi_module, "aimage_generation", side_effect=fake_aimage_generation + ), + patch.object( + gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0 + ), + ): + tool = gi_module.create_generate_image_tool( + search_space_id=1, db_session=MagicMock() + ) + await tool.ainvoke({"prompt": "a cat", "n": 1}) + + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-image-1" + + +def test_image_gen_router_deployment_sets_api_base_when_config_empty(): + """The Auto-mode router pool must also resolve ``api_base`` when an + OpenRouter config ships an empty string. The deployment dict is fed + straight to ``litellm.Router``, so a missing ``api_base`` would + leak the same way as the direct call sites. + """ + from app.services.image_gen_router_service import ImageGenRouterService + + deployment = ImageGenRouterService._config_to_deployment( + { + "model_name": "openai/gpt-image-1", + "provider": "OPENROUTER", + "api_key": "sk-or-test", + "api_base": "", + } + ) + assert deployment is not None + assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" + assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-image-1" diff --git a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py new file mode 100644 index 000000000..c309ff881 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py @@ -0,0 +1,226 @@ +"""LLMRouterService pool-filter / rebuild tests. + +These tests focus on the *config plumbing* (which configs enter the router +pool, rebuild resets state correctly). They stub out the underlying +``litellm.Router`` so we don't need real API keys or network access. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from app.services.llm_router_service import LLMRouterService + +pytestmark = pytest.mark.unit + + +def _fake_yaml_config( + *, + id: int, + model_name: str, + billing_tier: str = "free", +) -> dict: + return { + "id": id, + "name": f"yaml-{id}", + "provider": "OPENAI", + "model_name": model_name, + "api_key": "sk-test", + "api_base": "", + "billing_tier": billing_tier, + "rpm": 100, + "tpm": 100_000, + "litellm_params": {}, + } + + +def _fake_openrouter_config( + *, + id: int, + model_name: str, + billing_tier: str, + router_pool_eligible: bool | None = None, +) -> dict: + """Build a synthetic dynamic-OR config dict for router-pool tests. + + Defaults mirror Strategy 3: premium OR enters the pool, free OR stays + out. Callers can override ``router_pool_eligible`` to simulate legacy + configs or to regression-test the filter mechanics directly. 
+ """ + if router_pool_eligible is None: + router_pool_eligible = billing_tier == "premium" + return { + "id": id, + "name": f"or-{id}", + "provider": "OPENROUTER", + "model_name": model_name, + "api_key": "sk-or-test", + "api_base": "", + "billing_tier": billing_tier, + "rpm": 20 if billing_tier == "free" else 200, + "tpm": 100_000 if billing_tier == "free" else 1_000_000, + "litellm_params": {}, + "router_pool_eligible": router_pool_eligible, + } + + +def _reset_router_singleton() -> None: + instance = LLMRouterService.get_instance() + instance._initialized = False + instance._router = None + instance._model_list = [] + instance._premium_model_strings = set() + + +def test_router_pool_includes_or_premium_excludes_or_free(): + """Strategy 3: premium OR joins the pool, free OR stays out. + + Dynamic OpenRouter premium entries opt into load balancing alongside + curated YAML configs. Dynamic OR free entries are intentionally kept + out because OpenRouter's free tier enforces a single account-global + quota bucket that per-deployment router accounting can't represent. + """ + _reset_router_singleton() + configs = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"), + _fake_openrouter_config( + id=-10_001, model_name="openai/gpt-4o", billing_tier="premium" + ), + _fake_openrouter_config( + id=-10_002, + model_name="meta-llama/llama-3.3-70b:free", + billing_tier="free", + ), + ] + + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + LLMRouterService.initialize(configs) + + pool_models = { + dep["litellm_params"]["model"] + for dep in LLMRouterService.get_instance()._model_list + } + # YAML premium + YAML free + dynamic OR premium are all in the pool. + # Dynamic OR free is NOT (shared-bucket rate limits can't be load-balanced). + assert pool_models == { + "openai/gpt-4o", + "openai/gpt-4o-mini", + "openrouter/openai/gpt-4o", + } + + prem = LLMRouterService.get_instance()._premium_model_strings + # YAML premium is fingerprinted under both its model_string and its + # ``base_model`` form (existing behavior we don't want to regress). + assert "openai/gpt-4o" in prem + # Dynamic OR premium is now fingerprinted as premium so pool-level + # calls through the router are billed against premium quota. + assert "openrouter/openai/gpt-4o" in prem + assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True + # Dynamic OR free never enters the pool, so it's never counted as premium. + assert ( + LLMRouterService.is_premium_model("openrouter/meta-llama/llama-3.3-70b:free") + is False + ) + + +def test_router_pool_filter_mechanics_respect_override(): + """The ``router_pool_eligible`` filter itself works independently of tier. + + Regression guard: if a future refactor ever sets the flag False on a + premium config (e.g. for maintenance), that config MUST be skipped by + ``initialize`` even though its tier is premium. 
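+
+    Sketch of the guard being pinned (assumed shape): ``if not
+    cfg.get("router_pool_eligible", True): continue``. The default must
+    stay True because curated YAML configs never carry the flag, yet they
+    belong in the pool.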
+ """ + _reset_router_singleton() + configs = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + _fake_openrouter_config( + id=-10_001, + model_name="openai/gpt-4o", + billing_tier="premium", + router_pool_eligible=False, # opt out despite being premium + ), + ] + + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + LLMRouterService.initialize(configs) + + pool_models = { + dep["litellm_params"]["model"] + for dep in LLMRouterService.get_instance()._model_list + } + assert pool_models == {"openai/gpt-4o"} + assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is False + + +def test_rebuild_refreshes_pool_after_configs_change(): + _reset_router_singleton() + configs_v1 = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + ] + configs_v2 = [ + *configs_v1, + _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"), + ] + + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + + LLMRouterService.initialize(configs_v1) + assert len(LLMRouterService.get_instance()._model_list) == 1 + + # ``initialize`` should be a no-op here (already initialized). + LLMRouterService.initialize(configs_v2) + assert len(LLMRouterService.get_instance()._model_list) == 1 + + # ``rebuild`` must clear the guard and re-run with the new configs. + LLMRouterService.rebuild(configs_v2) + assert len(LLMRouterService.get_instance()._model_list) == 2 + + +def test_auto_model_pin_candidates_include_dynamic_openrouter(): + """Dynamic OR configs must remain Auto-mode thread-pin candidates. + + Guards against a future regression where someone adds the + ``router_pool_eligible`` filter to ``auto_model_pin_service._global_candidates``. 
+ """ + from app.config import config + from app.services.auto_model_pin_service import _global_candidates + + or_premium = _fake_openrouter_config( + id=-10_001, model_name="openai/gpt-4o", billing_tier="premium" + ) + or_free = _fake_openrouter_config( + id=-10_002, + model_name="meta-llama/llama-3.3-70b:free", + billing_tier="free", + ) + original = config.GLOBAL_LLM_CONFIGS + try: + config.GLOBAL_LLM_CONFIGS = [or_premium, or_free] + candidate_ids = {c["id"] for c in _global_candidates()} + assert candidate_ids == {-10_001, -10_002} + finally: + config.GLOBAL_LLM_CONFIGS = original diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py new file mode 100644 index 000000000..88fcf2db3 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -0,0 +1,380 @@ +"""Unit tests for the dynamic OpenRouter integration.""" + +from __future__ import annotations + +import pytest + +from app.services.openrouter_integration_service import ( + _OPENROUTER_DYNAMIC_MARKER, + _generate_configs, + _openrouter_tier, + _stable_config_id, +) + +pytestmark = pytest.mark.unit + + +def _minimal_openrouter_model( + *, + model_id: str, + pricing: dict | None = None, + name: str | None = None, +) -> dict: + """Return a synthetic OpenRouter /api/v1/models entry. + + The real API payload includes a lot of fields; we only populate what + ``_generate_configs`` actually inspects (architecture, tool support, + context, pricing, id). + """ + return { + "id": model_id, + "name": name or model_id, + "architecture": {"output_modalities": ["text"]}, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": pricing or {"prompt": "0.000003", "completion": "0.000015"}, + } + + +# --------------------------------------------------------------------------- +# _openrouter_tier +# --------------------------------------------------------------------------- + + +def test_openrouter_tier_free_suffix(): + assert _openrouter_tier({"id": "foo/bar:free"}) == "free" + + +def test_openrouter_tier_zero_pricing(): + model = { + "id": "foo/bar", + "pricing": {"prompt": "0", "completion": "0"}, + } + assert _openrouter_tier(model) == "free" + + +def test_openrouter_tier_paid(): + model = { + "id": "foo/bar", + "pricing": {"prompt": "0.000003", "completion": "0.000015"}, + } + assert _openrouter_tier(model) == "premium" + + +def test_openrouter_tier_missing_pricing_is_premium(): + assert _openrouter_tier({"id": "foo/bar"}) == "premium" + assert _openrouter_tier({"id": "foo/bar", "pricing": {}}) == "premium" + + +# --------------------------------------------------------------------------- +# _stable_config_id +# --------------------------------------------------------------------------- + + +def test_stable_config_id_deterministic(): + taken1: set[int] = set() + taken2: set[int] = set() + a = _stable_config_id("openai/gpt-4o", -10_000, taken1) + b = _stable_config_id("openai/gpt-4o", -10_000, taken2) + assert a == b + assert a < 0 + + +def test_stable_config_id_collision_decrements(): + """When two model_ids hash to the same slot, the second should decrement.""" + taken: set[int] = set() + a = _stable_config_id("openai/gpt-4o", -10_000, taken) + # Force a collision by pre-populating ``taken`` with a slot we know will be + # picked. 
+ taken_forced = {a} + b = _stable_config_id("openai/gpt-4o", -10_000, taken_forced) + assert b != a + assert b == a - 1 + assert b in taken_forced + + +def test_stable_config_id_different_models_different_ids(): + taken: set[int] = set() + ids = { + _stable_config_id("openai/gpt-4o", -10_000, taken), + _stable_config_id("anthropic/claude-3.5-sonnet", -10_000, taken), + _stable_config_id("google/gemini-2.0-flash", -10_000, taken), + } + assert len(ids) == 3 + + +def test_stable_config_id_survives_catalogue_churn(): + """Removing a model should not shift other models' IDs (the bug we fix).""" + taken1: set[int] = set() + id_a1 = _stable_config_id("openai/gpt-4o", -10_000, taken1) + _ = _stable_config_id("anthropic/claude-3-haiku", -10_000, taken1) + id_c1 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken1) + + taken2: set[int] = set() + id_a2 = _stable_config_id("openai/gpt-4o", -10_000, taken2) + id_c2 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken2) + + assert id_a1 == id_a2 + assert id_c1 == id_c2 + + +# --------------------------------------------------------------------------- +# _generate_configs +# --------------------------------------------------------------------------- + + +_SETTINGS_BASE: dict = { + "api_key": "sk-or-test", + "id_offset": -10_000, + "rpm": 200, + "tpm": 1_000_000, + "free_rpm": 20, + "free_tpm": 100_000, + "anonymous_enabled_paid": False, + "anonymous_enabled_free": True, + "quota_reserve_tokens": 4000, +} + + +def test_generate_configs_respects_tier(): + """Premium OR models opt into the router pool; free OR models stay out. + + Strategy-3 split: premium participates in LiteLLM Router load balancing, + free stays excluded because OpenRouter enforces a shared global free-tier + bucket that per-deployment router accounting can't represent. + """ + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + _minimal_openrouter_model( + model_id="meta-llama/llama-3.3-70b-instruct:free", + pricing={"prompt": "0", "completion": "0"}, + ), + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + by_model = {c["model_name"]: c for c in cfgs} + + paid = by_model["openai/gpt-4o"] + assert paid["billing_tier"] == "premium" + assert paid["rpm"] == 200 + assert paid["tpm"] == 1_000_000 + assert paid["anonymous_enabled"] is False + assert paid["router_pool_eligible"] is True + assert paid[_OPENROUTER_DYNAMIC_MARKER] is True + + free = by_model["meta-llama/llama-3.3-70b-instruct:free"] + assert free["billing_tier"] == "free" + assert free["rpm"] == 20 + assert free["tpm"] == 100_000 + assert free["anonymous_enabled"] is True + assert free["router_pool_eligible"] is False + + +def test_generate_configs_excludes_upstream_openrouter_free_router(): + """OpenRouter's own ``openrouter/free`` meta-router must never become a card. + + The upstream API returns this as a first-class zero-priced model, so + without an explicit blocklist entry it would slip through every other + filter (text output, tool calling, 200k context, non-Amazon) and land + in the selector as a duplicate of the concrete ``:free`` cards. The + exclusion in ``_EXCLUDED_MODEL_IDS`` prevents that. 
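+
+    Sketch of the guard (assumed shape): an early ``if model["id"] in
+    _EXCLUDED_MODEL_IDS: continue`` inside ``_generate_configs``, ahead
+    of the modality/tool/context filters.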
+ """ + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + _minimal_openrouter_model( + model_id="openrouter/free", + pricing={"prompt": "0", "completion": "0"}, + ), + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + model_names = {c["model_name"] for c in cfgs} + assert "openrouter/free" not in model_names + assert "openai/gpt-4o" in model_names + + +def test_generate_configs_drops_non_text_and_non_tool_models(): + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + { # image-output model + "id": "openai/dall-e", + "architecture": {"output_modalities": ["image"]}, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.01", "completion": "0.01"}, + }, + { # text but no tool calling + "id": "openai/completion-only", + "architecture": {"output_modalities": ["text"]}, + "supported_parameters": [], + "context_length": 200_000, + "pricing": {"prompt": "0.01", "completion": "0.01"}, + }, + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + model_names = [c["model_name"] for c in cfgs] + assert "openai/gpt-4o" in model_names + assert "openai/dall-e" not in model_names + assert "openai/completion-only" not in model_names + + +# --------------------------------------------------------------------------- +# _generate_image_gen_configs / _generate_vision_llm_configs +# --------------------------------------------------------------------------- + + +def test_generate_image_gen_configs_filters_by_image_output(): + """Only models with ``output_modalities`` containing ``image`` are emitted. + Tool-calling and context filters are intentionally NOT applied — image + generation has nothing to do with tool calls and context windows. + """ + from app.services.openrouter_integration_service import ( + _generate_image_gen_configs, + ) + + raw = [ + # Pure image-gen model (small context, no tools — should still emit). + { + "id": "openai/gpt-image-1", + "architecture": {"output_modalities": ["image"]}, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + }, + # Multi-modal: text+image output (should still emit). + { + "id": "google/gemini-2.5-flash-image", + "architecture": {"output_modalities": ["text", "image"]}, + "context_length": 1_000_000, + "pricing": {"prompt": "0.000001", "completion": "0.000004"}, + }, + # Pure text model — must NOT emit. + { + "id": "openai/gpt-4o", + "architecture": {"output_modalities": ["text"]}, + "context_length": 128_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + ] + + cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE)) + model_names = {c["model_name"] for c in cfgs} + assert "openai/gpt-image-1" in model_names + assert "google/gemini-2.5-flash-image" in model_names + assert "openai/gpt-4o" not in model_names + + # Each config must carry ``billing_tier`` for routing in image_generation_routes. + for c in cfgs: + assert c["billing_tier"] in {"free", "premium"} + assert c["provider"] == "OPENROUTER" + assert c[_OPENROUTER_DYNAMIC_MARKER] is True + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't 404 against an inherited Azure endpoint. + assert c["api_base"] == "https://openrouter.ai/api/v1" + + +def test_generate_image_gen_configs_assigns_image_id_offset(): + """Image configs use a different id_offset (-20000) so their negative + IDs don't collide with chat configs (-10000) or vision configs (-30000). 
+ """ + from app.services.openrouter_integration_service import ( + _generate_image_gen_configs, + ) + + raw = [ + { + "id": "openai/gpt-image-1", + "architecture": {"output_modalities": ["image"]}, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + } + ] + # Don't pass image_id_offset → use the module default (-20000). + cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE)) + assert all(c["id"] < -20_000 + 1 for c in cfgs) + assert all(c["id"] > -29_000_000 for c in cfgs) + + +def test_generate_vision_llm_configs_filters_by_image_input_text_output(): + """Vision LLMs must accept image input AND emit text — pure image-gen + (no text out) and text-only (no image in) models are excluded. + """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + # GPT-4o: vision LLM (image in, text out) — must emit. + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "context_length": 128_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + # Pure image generator — image *output*, no text out. Must NOT emit. + { + "id": "openai/gpt-image-1", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["image"], + }, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + }, + # Pure text model (no image in). Must NOT emit. + { + "id": "anthropic/claude-3-haiku", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + "context_length": 200_000, + "pricing": {"prompt": "0.000001", "completion": "0.000005"}, + }, + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + names = {c["model_name"] for c in cfgs} + assert names == {"openai/gpt-4o"} + + cfg = cfgs[0] + assert cfg["billing_tier"] == "premium" + # Pricing carried inline so pricing_registration can register vision + # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache + # is cleared. + assert cfg["input_cost_per_token"] == pytest.approx(5e-6) + assert cfg["output_cost_per_token"] == pytest.approx(15e-6) + assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't inherit an Azure endpoint. + assert cfg["api_base"] == "https://openrouter.ai/api/v1" + + +def test_generate_vision_llm_configs_drops_chat_only_filters(): + """A small-context vision model that doesn't advertise tool calling is + still a valid vision LLM for "describe this image" prompts. The chat + filters (``supports_tool_calling``, ``has_sufficient_context``) must + NOT be applied to vision emission. 
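+
+    Together with the filter test above, this pins the vision emission
+    predicate to (sketch): ``"image" in architecture["input_modalities"]
+    and "text" in architecture["output_modalities"]``, and nothing else.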
+ """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + { + "id": "tiny/vision-mini", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "supported_parameters": [], # no tools + "context_length": 4_000, # well below MIN_CONTEXT_LENGTH + "pricing": {"prompt": "0.0000001", "completion": "0.0000005"}, + } + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + assert len(cfgs) == 1 + assert cfgs[0]["model_name"] == "tiny/vision-mini" diff --git a/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py new file mode 100644 index 000000000..4eb1f2295 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py @@ -0,0 +1,108 @@ +"""Tests for deprecated-key warnings and back-compat in +``load_openrouter_integration_settings``. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.unit + + +def _write_yaml(tmp_path: Path, body: str) -> Path: + cfg_dir = tmp_path / "app" / "config" + cfg_dir.mkdir(parents=True) + cfg_path = cfg_dir / "global_llm_config.yaml" + cfg_path.write_text(body, encoding="utf-8") + return cfg_path + + +def _patch_base_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + +def test_legacy_billing_tier_emits_warning(monkeypatch, tmp_path, capsys): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + api_key: "sk-or-test" + billing_tier: "premium" +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + captured = capsys.readouterr().out + assert settings is not None + assert "billing_tier is deprecated" in captured + + +def test_legacy_anonymous_enabled_back_compat(monkeypatch, tmp_path, capsys): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + api_key: "sk-or-test" + anonymous_enabled: true +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + captured = capsys.readouterr().out + assert settings is not None + assert settings["anonymous_enabled_paid"] is True + assert settings["anonymous_enabled_free"] is True + assert "anonymous_enabled is" in captured + assert "deprecated" in captured + + +def test_new_keys_take_priority_over_legacy_back_compat(monkeypatch, tmp_path, capsys): + """If both legacy and new keys are present, new keys win (setdefault).""" + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + api_key: "sk-or-test" + anonymous_enabled: true + anonymous_enabled_paid: false + anonymous_enabled_free: false +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + capsys.readouterr() + assert settings is not None + assert settings["anonymous_enabled_paid"] is False + assert settings["anonymous_enabled_free"] is False + + +def test_disabled_integration_returns_none(monkeypatch, tmp_path): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: false + api_key: "sk-or-test" +""".lstrip(), + ) + 
_patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + assert load_openrouter_integration_settings() is None diff --git a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py new file mode 100644 index 000000000..1c74aa928 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py @@ -0,0 +1,331 @@ +"""Unit tests for the OpenRouter ``_enrich_health`` background task.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from app.services.openrouter_integration_service import ( + OpenRouterIntegrationService, +) +from app.services.quality_score import ( + _HEALTH_FAIL_RATIO_FALLBACK, +) + +pytestmark = pytest.mark.unit + + +def _or_cfg( + *, + cid: int, + model_name: str, + tier: str = "premium", + static_score: int = 50, +) -> dict: + return { + "id": cid, + "provider": "OPENROUTER", + "model_name": model_name, + "billing_tier": tier, + "auto_pin_tier": "B" if tier == "premium" else "C", + "quality_score_static": static_score, + "quality_score_health": None, + "quality_score": static_score, + "health_gated": False, + } + + +class _StubResponse: + def __init__(self, *, payload: dict, status_code: int = 200): + self._payload = payload + self.status_code = status_code + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self) -> dict: + return self._payload + + +class _StubAsyncClient: + """Minimal drop-in for ``httpx.AsyncClient`` used by ``_fetch_endpoints``.""" + + def __init__(self, responder): + self._responder = responder + self.requests: list[str] = [] + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def get(self, url: str, headers: dict | None = None) -> _StubResponse: + self.requests.append(url) + return self._responder(url) + + +def _patch_async_client(monkeypatch, responder) -> _StubAsyncClient: + """Replace ``httpx.AsyncClient`` for the duration of the test.""" + client = _StubAsyncClient(responder) + monkeypatch.setattr( + "app.services.openrouter_integration_service.httpx.AsyncClient", + lambda *_args, **_kwargs: client, + ) + return client + + +def _healthy_payload() -> dict: + return { + "data": { + "endpoints": [ + { + "status": 0, + "uptime_last_30m": 0.99, + "uptime_last_1d": 0.995, + "uptime_last_5m": 0.99, + } + ] + } + } + + +def _unhealthy_payload() -> dict: + return { + "data": { + "endpoints": [ + { + "status": 0, + "uptime_last_30m": 0.55, + "uptime_last_1d": 0.62, + "uptime_last_5m": 0.50, + } + ] + } + } + + +# --------------------------------------------------------------------------- +# Bounded fan-out + happy path +# --------------------------------------------------------------------------- + + +async def test_enrich_health_marks_healthy_and_gates_unhealthy(monkeypatch): + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + _or_cfg(cid=-2, model_name="venice/dead-model", static_score=60), + ] + + def responder(url: str) -> _StubResponse: + if "anthropic" in url: + return _StubResponse(payload=_healthy_payload()) + return _StubResponse(payload=_unhealthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {"api_key": ""} + await service._enrich_health(cfgs) + + healthy = next(c for c in cfgs if c["id"] == -1) + gated 
= next(c for c in cfgs if c["id"] == -2) + + assert healthy["health_gated"] is False + assert healthy["quality_score_health"] is not None + assert healthy["quality_score"] >= healthy["quality_score_static"] + + assert gated["health_gated"] is True + assert gated["quality_score"] == gated["quality_score_static"] + + +async def test_enrich_health_only_touches_or_provider(monkeypatch): + """YAML cfgs that aren't OPENROUTER must be skipped entirely.""" + yaml_cfg = { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score_static": 80, + "quality_score": 80, + "health_gated": False, + } + or_cfg = _or_cfg(cid=-2, model_name="anthropic/claude-haiku") + + requests: list[str] = [] + + def responder(url: str) -> _StubResponse: + requests.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health([yaml_cfg, or_cfg]) + + assert all("anthropic/claude-haiku" in r for r in requests) + # YAML cfg is untouched. + assert yaml_cfg["quality_score"] == 80 + assert yaml_cfg["health_gated"] is False + + +# --------------------------------------------------------------------------- +# Failure ratio fallback +# --------------------------------------------------------------------------- + + +async def test_enrich_health_falls_back_to_last_good_when_failure_ratio_high( + monkeypatch, +): + """If >= 25% of fetches fail, keep last-good cache instead of writing + partial data.""" + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + _or_cfg(cid=-2, model_name="openai/gpt-5", static_score=80), + _or_cfg(cid=-3, model_name="google/gemini-flash", static_score=65), + _or_cfg(cid=-4, model_name="venice/something", static_score=50), + ] + + service = OpenRouterIntegrationService() + service._settings = {} + # Pre-seed last-good cache with a known-healthy snapshot. + service._health_cache = { + "anthropic/claude-haiku": {"gated": False, "score": 95.0}, + } + + def all_fail(_url: str) -> _StubResponse: + return _StubResponse(payload={}, status_code=500) + + _patch_async_client(monkeypatch, all_fail) + await service._enrich_health(cfgs) + + # Above threshold ⇒ degraded; last-good cache wins for the cached cfg. + cached_hit = next(c for c in cfgs if c["model_name"] == "anthropic/claude-haiku") + assert cached_hit["quality_score_health"] == 95.0 + assert cached_hit["health_gated"] is False + # Confirm the threshold constant we're testing against is real. 
+ assert _HEALTH_FAIL_RATIO_FALLBACK <= 1.0 + + +async def test_enrich_health_keeps_static_only_with_no_cache_and_failures( + monkeypatch, +): + """If a fetch fails and there's no last-good cache, the cfg keeps its + static-only ``quality_score`` and is *not* gated by default.""" + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + ] + + def fail(_url: str) -> _StubResponse: + return _StubResponse(payload={}, status_code=500) + + _patch_async_client(monkeypatch, fail) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health(cfgs) + + cfg = cfgs[0] + assert cfg["health_gated"] is False + assert cfg["quality_score"] == cfg["quality_score_static"] + assert cfg["quality_score_health"] is None + + +# --------------------------------------------------------------------------- +# Last-good cache: success populates, next failure reuses +# --------------------------------------------------------------------------- + + +async def test_enrich_health_populates_cache_on_success_then_reuses_on_failure( + monkeypatch, +): + cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70) + + service = OpenRouterIntegrationService() + service._settings = {} + + def healthy(_url: str) -> _StubResponse: + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, healthy) + await service._enrich_health([cfg]) + + assert "anthropic/claude-haiku" in service._health_cache + cached_score = service._health_cache["anthropic/claude-haiku"]["score"] + assert cached_score is not None + + # Next cycle: enough other healthy cfgs so failure ratio stays below + # the 25% threshold even when this one fails individually. + other_cfgs = [ + _or_cfg(cid=-2 - i, model_name=f"healthy/m-{i}", static_score=60) + for i in range(10) + ] + cfg["quality_score_health"] = None + cfg["quality_score"] = cfg["quality_score_static"] + + def mixed(url: str) -> _StubResponse: + if "anthropic" in url: + return _StubResponse(payload={}, status_code=500) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, mixed) + await service._enrich_health([cfg, *other_cfgs]) + + assert cfg["quality_score_health"] == cached_score + assert cfg["health_gated"] is False + + +# --------------------------------------------------------------------------- +# Bounded fan-out: respects top-N caps +# --------------------------------------------------------------------------- + + +async def test_enrich_health_bounds_premium_fanout(monkeypatch): + """Top-N premium cap is honoured even when many cfgs are present.""" + from app.services.quality_score import _HEALTH_ENRICH_TOP_N_PREMIUM + + cfgs = [ + _or_cfg( + cid=-i, model_name=f"openai/m-{i}", tier="premium", static_score=100 - i + ) + for i in range(1, _HEALTH_ENRICH_TOP_N_PREMIUM + 20) + ] + + seen: list[str] = [] + + def responder(url: str) -> _StubResponse: + seen.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health(cfgs) + + assert len(seen) == _HEALTH_ENRICH_TOP_N_PREMIUM + + +async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch): + """When the catalogue has no OR cfgs at all, no HTTP calls fire.""" + yaml_cfg: dict[str, Any] = { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "billing_tier": "premium", + } + requests: list[str] = [] + + def responder(url: str) -> _StubResponse: 
+ requests.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health([yaml_cfg]) + assert requests == [] diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py new file mode 100644 index 000000000..e97250ff2 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py @@ -0,0 +1,447 @@ +"""Pricing registration unit tests. + +The pricing-registration module is what makes ``response_cost`` populate +correctly for OpenRouter dynamic models and operator-defined Azure +deployments — both of which LiteLLM doesn't natively know about. The tests +exercise: + +* The alias generators emit every shape that LiteLLM's cost-callback might + use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``, + ``provider/base_model``, ``provider/model_name``, plus the special + ``azure_openai`` → ``azure`` normalisation). +* ``register_pricing_from_global_configs`` calls ``litellm.register_model`` + with the right alias set and pricing values per provider. +* Configs without a resolvable pair of cost values are skipped — never + registered as zero, since that would override pricing LiteLLM might + already know natively. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Alias generators +# --------------------------------------------------------------------------- + + +def test_openrouter_alias_set_includes_prefixed_and_bare(): + from app.services.pricing_registration import _alias_set_for_openrouter + + aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet") + assert aliases == [ + "openrouter/anthropic/claude-3-5-sonnet", + "anthropic/claude-3-5-sonnet", + ] + + +def test_openrouter_alias_set_dedupes(): + """If the model id is already prefixed with ``openrouter/``, the alias + set must not contain duplicates that would re-register the same key + twice. + """ + from app.services.pricing_registration import _alias_set_for_openrouter + + aliases = _alias_set_for_openrouter("openrouter/foo") + # The bare and prefixed variants compute to the same string here, so we + # at minimum require uniqueness. + assert len(aliases) == len(set(aliases)) + + +def test_yaml_alias_set_for_azure_openai_normalises_to_azure(): + """``azure_openai`` (our YAML provider slug) must register under + ``azure/`` so the LiteLLM Router's deployment-resolution path + (which uses provider ``azure``) finds the pricing too. + """ + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="AZURE_OPENAI", + model_name="gpt-5.4", + base_model="gpt-5.4", + ) + assert "gpt-5.4" in aliases + assert "azure_openai/gpt-5.4" in aliases + assert "azure/gpt-5.4" in aliases + + +def test_yaml_alias_set_distinguishes_model_name_and_base_model(): + """When ``model_name`` differs from ``base_model`` (operator labelled a + deployment), both must appear in the alias set since either may surface + in callbacks depending on the call path. 
+ """ + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="OPENAI", + model_name="my-deployment-label", + base_model="gpt-4o", + ) + assert "gpt-4o" in aliases + assert "openai/gpt-4o" in aliases + assert "my-deployment-label" in aliases + assert "openai/my-deployment-label" in aliases + + +def test_yaml_alias_set_omits_provider_prefix_when_provider_blank(): + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="", + model_name="foo", + base_model="bar", + ) + assert "bar" in aliases + assert "foo" in aliases + assert all("/" not in a for a in aliases) + + +# --------------------------------------------------------------------------- +# register_pricing_from_global_configs +# --------------------------------------------------------------------------- + + +class _RegistrationSpy: + """Captures the dicts passed to ``litellm.register_model``. + + Many calls may go through; we just record them all and let tests assert + against the union. + """ + + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + + def __call__(self, payload: dict[str, Any]) -> None: + self.calls.append(payload) + + @property + def all_keys(self) -> set[str]: + keys: set[str] = set() + for payload in self.calls: + keys.update(payload.keys()) + return keys + + +def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy: + spy = _RegistrationSpy() + monkeypatch.setattr( + "app.services.pricing_registration.litellm.register_model", + spy, + raising=False, + ) + return spy + + +def _patch_openrouter_pricing( + monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]] +) -> None: + """Pretend the OpenRouter integration is initialised with ``mapping``.""" + + class _Stub: + def get_raw_pricing(self) -> dict[str, dict[str, str]]: + return mapping + + class _StubService: + @classmethod + def is_initialized(cls) -> bool: + return True + + @classmethod + def get_instance(cls) -> _Stub: + return _Stub() + + monkeypatch.setattr( + "app.services.openrouter_integration_service.OpenRouterIntegrationService", + _StubService, + raising=False, + ) + + +def test_openrouter_models_register_under_aliases(monkeypatch): + """An OpenRouter config whose ``model_name`` is in the cached raw + pricing map is registered under both ``openrouter/X`` and bare ``X``. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, + { + "anthropic/claude-3-5-sonnet": { + "prompt": "0.000003", + "completion": "0.000015", + } + }, + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys + assert "anthropic/claude-3-5-sonnet" in spy.all_keys + # Costs are float-converted from the raw OpenRouter strings. 
+ payload = spy.calls[0] + assert payload["openrouter/anthropic/claude-3-5-sonnet"][ + "input_cost_per_token" + ] == pytest.approx(3e-6) + assert payload["openrouter/anthropic/claude-3-5-sonnet"][ + "output_cost_per_token" + ] == pytest.approx(15e-6) + assert ( + payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"] + == "openrouter" + ) + + +def test_yaml_override_registers_under_alias_set(monkeypatch): + """Operator-declared ``input_cost_per_token`` / + ``output_cost_per_token`` on a YAML config registers under every + alias the YAML alias generator produces — including the ``azure/`` + normalisation for ``azure_openai`` providers. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5.4", + "litellm_params": { + "base_model": "gpt-5.4", + "input_cost_per_token": 2e-6, + "output_cost_per_token": 8e-6, + }, + } + ], + ) + + register_pricing_from_global_configs() + + keys = spy.all_keys + assert "gpt-5.4" in keys + assert "azure_openai/gpt-5.4" in keys + assert "azure/gpt-5.4" in keys + + payload = spy.calls[0] + entry = payload["gpt-5.4"] + assert entry["input_cost_per_token"] == pytest.approx(2e-6) + assert entry["output_cost_per_token"] == pytest.approx(8e-6) + assert entry["litellm_provider"] == "azure" + + +def test_no_override_means_no_registration(monkeypatch): + """A YAML config that *omits* both pricing fields must NOT be registered + — registering as zero would override LiteLLM's native pricing for the + ``base_model`` key (e.g. ``gpt-4o``) and silently make every user's + bill drop to $0. Fail-safe is "skip and warn", not "register zero". + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENAI", + "model_name": "gpt-4o", + "litellm_params": {"base_model": "gpt-4o"}, + } + ], + ) + + register_pricing_from_global_configs() + + assert spy.calls == [] + + +def test_openrouter_skipped_when_pricing_missing(monkeypatch): + """If the OpenRouter raw-pricing cache doesn't carry an entry for a + configured model (network blip during refresh, model added later, etc.), + we skip it rather than registering zero pricing. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}} + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + } + ], + ) + + register_pricing_from_global_configs() + + assert spy.calls == [] + + +def test_register_continues_after_individual_failure(monkeypatch, caplog): + """A single bad ``register_model`` call (e.g. raising LiteLLM error) + must not abort registration of the remaining configs. 
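+
+    Implies a per-config try/except in the registration walk (sketch,
+    assumed shape)::
+
+        for payload in payloads:
+            try:
+                litellm.register_model(payload)
+            except Exception:
+                # log and keep walking; one bad entry must not
+                # abort the remaining registrations
+                continue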
+ """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"} + successful_calls: list[dict[str, Any]] = [] + + def _maybe_fail(payload: dict[str, Any]) -> None: + if any(k in failing_keys for k in payload): + raise RuntimeError("boom") + successful_calls.append(payload) + + monkeypatch.setattr( + "app.services.pricing_registration.litellm.register_model", + _maybe_fail, + raising=False, + ) + _patch_openrouter_pricing( + monkeypatch, + { + "anthropic/claude-3-5-sonnet": { + "prompt": "0.000003", + "completion": "0.000015", + } + }, + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + }, + { + "id": 2, + "provider": "OPENAI", + "model_name": "custom-deployment", + "litellm_params": { + "base_model": "custom-deployment", + "input_cost_per_token": 1e-6, + "output_cost_per_token": 2e-6, + }, + }, + ], + ) + + register_pricing_from_global_configs() + + # The good config still registered. + assert any("custom-deployment" in payload for payload in successful_calls) + + +def test_vision_configs_registered_with_chat_shape(monkeypatch): + """``register_pricing_from_global_configs`` walks + ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision + calls (during indexing) bill correctly. Vision configs use the same + chat-shape token prices, but image-gen pricing is intentionally NOT + registered here (handled via ``response_cost`` in LiteLLM). + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, + {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}}, + ) + + # No chat configs — only vision. Proves the vision walk is a separate + # iteration, not piggy-backed on the chat list. + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "billing_tier": "premium", + "input_cost_per_token": 5e-6, + "output_cost_per_token": 15e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/openai/gpt-4o" in spy.all_keys + payload_value = spy.calls[0]["openrouter/openai/gpt-4o"] + assert payload_value["mode"] == "chat" + assert payload_value["litellm_provider"] == "openrouter" + assert payload_value["input_cost_per_token"] == pytest.approx(5e-6) + assert payload_value["output_cost_per_token"] == pytest.approx(15e-6) + + +def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch): + """If the OpenRouter pricing cache misses a vision model (different + catalogue surface), the vision walk falls back to inline + ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself. 
+ """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash", + "billing_tier": "premium", + "input_cost_per_token": 1e-6, + "output_cost_per_token": 4e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/google/gemini-2.5-flash" in spy.all_keys diff --git a/surfsense_backend/tests/unit/services/test_provider_api_base.py b/surfsense_backend/tests/unit/services/test_provider_api_base.py new file mode 100644 index 000000000..12cd0a3d5 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_provider_api_base.py @@ -0,0 +1,107 @@ +"""Unit tests for the shared ``api_base`` resolver. + +The cascade exists so vision and image-gen call sites can't silently +inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``) +when an OpenRouter / Groq / etc. config ships an empty string. See +``provider_api_base`` module docstring for the original repro +(OpenRouter image-gen 404-ing against an Azure endpoint). +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_api_base import ( + PROVIDER_DEFAULT_API_BASE, + PROVIDER_KEY_DEFAULT_API_BASE, + resolve_api_base, +) + +pytestmark = pytest.mark.unit + + +def test_config_value_wins_over_defaults(): + """A non-empty config value is always returned verbatim, even when the + provider has a default — the operator gets the last word.""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base="https://my-openrouter-mirror.example.com/v1", + ) + assert result == "https://my-openrouter-mirror.example.com/v1" + + +def test_provider_key_default_when_config_missing(): + """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own + base URL — the provider-key map must take precedence over the prefix + map so DeepSeek requests don't go to OpenAI.""" + result = resolve_api_base( + provider="DEEPSEEK", + provider_prefix="openai", + config_api_base=None, + ) + assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_provider_prefix_default_when_no_key_default(): + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=None, + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_unknown_provider_returns_none(): + """When neither map matches we return ``None`` so the caller can let + LiteLLM apply its own provider-integration default (Azure deployment + URL, custom-provider URL, etc.).""" + result = resolve_api_base( + provider="SOMETHING_NEW", + provider_prefix="something_new", + config_api_base=None, + ) + assert result is None + + +def test_empty_string_config_treated_as_missing(): + """The original bug: OpenRouter dynamic configs ship ``api_base=""`` + and downstream call sites use ``if cfg.get("api_base"):`` — empty + strings are falsy in Python but the cascade has to step in anyway.""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base="", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_whitespace_only_config_treated_as_missing(): + """A config value of ``" "`` is a configuration mistake — treat it + as missing 
instead of forwarding whitespace to LiteLLM (which would + almost certainly 404).""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=" ", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_provider_case_insensitive(): + """Some call sites pass the provider lowercase (DB enum value), others + uppercase (YAML key). Both must resolve.""" + upper = resolve_api_base( + provider="DEEPSEEK", provider_prefix="openai", config_api_base=None + ) + lower = resolve_api_base( + provider="deepseek", provider_prefix="openai", config_api_base=None + ) + assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_all_inputs_none_returns_none(): + assert ( + resolve_api_base(provider=None, provider_prefix=None, config_api_base=None) + is None + ) diff --git a/surfsense_backend/tests/unit/services/test_provider_capabilities.py b/surfsense_backend/tests/unit/services/test_provider_capabilities.py new file mode 100644 index 000000000..aac88977f --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_provider_capabilities.py @@ -0,0 +1,244 @@ +"""Unit tests for the shared chat-image capability resolver. + +Two resolvers, two intents: + +- ``derive_supports_image_input`` — best-effort True for the catalog and + selector. Default-allow on unknown / unmapped models. The streaming + task safety net never sees this value directly. + +- ``is_known_text_only_chat_model`` — strict opt-out for the safety net. + Returns True only when LiteLLM's model map *explicitly* sets + ``supports_vision=False``. Anything else (missing key, exception, + True) returns False so the request flows through to the provider. +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_capabilities import ( + derive_supports_image_input, + is_known_text_only_chat_model, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# derive_supports_image_input — OpenRouter modalities path (authoritative) +# --------------------------------------------------------------------------- + + +def test_or_modalities_with_image_returns_true(): + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="openai/gpt-4o", + openrouter_input_modalities=["text", "image"], + ) + is True + ) + + +def test_or_modalities_text_only_returns_false(): + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="deepseek/deepseek-v3.2-exp", + openrouter_input_modalities=["text"], + ) + is False + ) + + +def test_or_modalities_empty_list_returns_false(): + """OR explicitly publishing an empty modality list is a definitive + 'no inputs at all' signal — treat as False rather than falling back + to LiteLLM.""" + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="weird/empty-modalities", + openrouter_input_modalities=[], + ) + is False + ) + + +def test_or_modalities_none_falls_through_to_litellm(): + """``None`` (missing key) is *not* a definitive signal — fall through + to LiteLLM. 
Using ``openai/gpt-4o`` which is in LiteLLM's map.""" + assert ( + derive_supports_image_input( + provider="OPENAI", + model_name="gpt-4o", + openrouter_input_modalities=None, + ) + is True + ) + + +# --------------------------------------------------------------------------- +# derive_supports_image_input — LiteLLM model-map path +# --------------------------------------------------------------------------- + + +def test_litellm_known_vision_model_returns_true(): + assert ( + derive_supports_image_input( + provider="OPENAI", + model_name="gpt-4o", + ) + is True + ) + + +def test_litellm_base_model_wins_over_model_name(): + """Azure-style entries pass model_name=deployment_id and put the + canonical sku in litellm_params.base_model. The resolver must + consult base_model first or the deployment id (which LiteLLM + doesn't know) would shadow the real capability.""" + assert ( + derive_supports_image_input( + provider="AZURE_OPENAI", + model_name="my-azure-deployment-id", + base_model="gpt-4o", + ) + is True + ) + + +def test_litellm_unknown_model_default_allows(): + """Default-allow on unknown — the safety net is the actual block.""" + assert ( + derive_supports_image_input( + provider="CUSTOM", + model_name="brand-new-model-x9-unmapped", + custom_provider="brand_new_proxy", + ) + is True + ) + + +def test_litellm_known_text_only_returns_false(): + """A model that LiteLLM explicitly knows is text-only resolves to + False even via the catalog resolver. ``deepseek-chat`` (the + DeepSeek-V3 chat sku) is in the map without supports_vision and + LiteLLM's `supports_vision` returns False.""" + # Sanity: confirm the helper's negative path. We use a small model + # known not to support vision per the map. + result = derive_supports_image_input( + provider="DEEPSEEK", + model_name="deepseek-chat", + ) + # We accept either False (LiteLLM said explicit no) or True + # (default-allow if the entry isn't mapped on this version) — the + # invariant is that the resolver never *raises* on a known-text-only + # provider/model. The behaviour-binding assertion lives in + # ``test_is_known_text_only_chat_model_explicit_false`` below. + assert isinstance(result, bool) + + +# --------------------------------------------------------------------------- +# is_known_text_only_chat_model — strict opt-out semantics +# --------------------------------------------------------------------------- + + +def test_is_known_text_only_returns_false_for_vision_model(): + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_is_known_text_only_returns_false_for_unknown_model(): + """Strict opt-out: missing from the map ≠ text-only. The safety net + must NOT fire for an unmapped model — that's the regression we're + fixing.""" + assert ( + is_known_text_only_chat_model( + provider="CUSTOM", + model_name="brand-new-model-x9-unmapped", + custom_provider="brand_new_proxy", + ) + is False + ) + + +def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch): + """LiteLLM's ``get_model_info`` raises freely on parse errors. 
The + helper swallows the exception and returns False so the safety net + doesn't fire on a transient lookup failure.""" + import app.services.provider_capabilities as pc + + def _raise(**_kwargs): + raise ValueError("intentional test failure") + + monkeypatch.setattr(pc.litellm, "get_model_info", _raise) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch): + """Stub LiteLLM's ``get_model_info`` to return an explicit False so + we exercise the opt-out path deterministically. Using a stub keeps + the test stable across LiteLLM map updates.""" + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"supports_vision": False, "max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is True + ) + + +def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch): + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"supports_vision": True} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is False + ) + + +def test_is_known_text_only_returns_false_on_missing_key(monkeypatch): + """A model entry without ``supports_vision`` at all is treated as + 'unknown' — strict opt-out means False.""" + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"max_input_tokens": 8192} # no supports_vision + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is False + ) diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py new file mode 100644 index 000000000..6fbc8fd62 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -0,0 +1,345 @@ +"""Unit tests for the Auto (Fastest) quality scoring module.""" + +from __future__ import annotations + +import time + +import pytest + +from app.services.quality_score import ( + _HEALTH_GATE_UPTIME_PCT, + _OPERATOR_TRUST_BONUS, + aggregate_health, + capabilities_signal, + context_signal, + created_recency_signal, + pricing_band, + slug_penalty, + static_score_or, + static_score_yaml, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# created_recency_signal +# --------------------------------------------------------------------------- + + +def test_created_recency_signal_recent_model_scores_high(): + now = 1_750_000_000 # ~mid-2025 + one_month_ago = now - (30 * 86_400) + assert created_recency_signal(one_month_ago, now) == 20 + + +def test_created_recency_signal_old_model_scores_zero(): + now = 1_750_000_000 + five_years_ago = now - (5 * 365 * 86_400) + assert created_recency_signal(five_years_ago, now) == 0 + + +def test_created_recency_signal_missing_timestamp_is_neutral(): + now = 1_750_000_000 + assert created_recency_signal(None, now) == 0 + assert created_recency_signal(0, now) == 0 + + +def test_created_recency_signal_monotonic_decay(): + now = 1_750_000_000 + scores = [ + created_recency_signal(now - days * 86_400, now) + for days in (30, 120, 300, 500, 700, 1000, 1500) + ] + assert scores == sorted(scores, 
reverse=True) + + +# --------------------------------------------------------------------------- +# pricing_band +# --------------------------------------------------------------------------- + + +def test_pricing_band_free_returns_zero(): + assert pricing_band("0", "0") == 0 + assert pricing_band(0.0, 0.0) == 0 + assert pricing_band(None, None) == 0 + + +def test_pricing_band_handles_unparseable(): + assert pricing_band("not-a-number", "0") == 0 + assert pricing_band({}, []) == 0 # type: ignore[arg-type] + + +def test_pricing_band_premium_tiers_increase_with_price(): + cheap = pricing_band("0.0000003", "0.0000005") + mid = pricing_band("0.000003", "0.000015") + flagship = pricing_band("0.00001", "0.00005") + assert 0 < cheap < mid < flagship + + +# --------------------------------------------------------------------------- +# context_signal +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "ctx,expected", + [ + (1_500_000, 10), + (1_000_000, 10), + (500_000, 8), + (200_000, 6), + (128_000, 4), + (100_000, 2), + (50_000, 0), + (0, 0), + (None, 0), + ], +) +def test_context_signal_bands(ctx, expected): + assert context_signal(ctx) == expected + + +# --------------------------------------------------------------------------- +# capabilities_signal +# --------------------------------------------------------------------------- + + +def test_capabilities_signal_caps_at_five(): + assert ( + capabilities_signal( + ["tools", "structured_outputs", "reasoning", "include_reasoning"] + ) + <= 5 + ) + + +def test_capabilities_signal_tools_only(): + assert capabilities_signal(["tools"]) == 2 + + +def test_capabilities_signal_empty(): + assert capabilities_signal(None) == 0 + assert capabilities_signal([]) == 0 + + +# --------------------------------------------------------------------------- +# slug_penalty +# --------------------------------------------------------------------------- + + +def test_slug_penalty_demotes_tiny_models(): + assert slug_penalty("meta-llama/llama-3.2-1b-instruct") < 0 + assert slug_penalty("liquid/lfm-7b") < 0 + assert slug_penalty("google/gemma-3n-e4b-it") < 0 + + +def test_slug_penalty_skips_capable_mini_nano_lite_models(): + """Critical Option C+ regression: don't penalise modern frontier + models named ``-nano`` / ``-mini`` / ``-lite`` (gpt-5-mini, etc.).""" + assert slug_penalty("openai/gpt-5-mini") == 0 + assert slug_penalty("openai/gpt-5-nano") == 0 + assert slug_penalty("google/gemini-2.5-flash-lite") == 0 + assert slug_penalty("anthropic/claude-haiku-4.5") == 0 + + +def test_slug_penalty_demotes_legacy_variants(): + assert slug_penalty("openai/o1-preview") < 0 + assert slug_penalty("foo/bar-base") < 0 + assert slug_penalty("foo/bar-distill") < 0 + + +def test_slug_penalty_empty_input(): + assert slug_penalty("") == 0 + + +# --------------------------------------------------------------------------- +# static_score_or +# --------------------------------------------------------------------------- + + +def _or_model( + *, + model_id: str, + created: int | None = None, + prompt: str = "0.000003", + completion: str = "0.000015", + context: int = 200_000, + params: list[str] | None = None, +) -> dict: + return { + "id": model_id, + "created": created, + "pricing": {"prompt": prompt, "completion": completion}, + "context_length": context, + "supported_parameters": params if params is not None else ["tools"], + } + + +def test_static_score_or_frontier_premium_beats_free_tiny(): + now = 1_750_000_000 + frontier = 
_or_model( + model_id="openai/gpt-5", + created=now - (60 * 86_400), + prompt="0.000005", + completion="0.000020", + context=400_000, + params=["tools", "structured_outputs", "reasoning"], + ) + tiny_free = _or_model( + model_id="meta-llama/llama-3.2-1b-instruct:free", + created=now - (5 * 365 * 86_400), + prompt="0", + completion="0", + context=128_000, + params=["tools"], + ) + assert static_score_or(frontier, now_ts=now) > static_score_or( + tiny_free, now_ts=now + ) + + +def test_static_score_or_score_is_clamped_0_to_100(): + now = int(time.time()) + score = static_score_or(_or_model(model_id="openai/gpt-4o"), now_ts=now) + assert 0 <= score <= 100 + + +def test_static_score_or_unknown_provider_is_neutral_not_zero(): + now = int(time.time()) + score = static_score_or( + _or_model(model_id="some-new-lab/some-model"), + now_ts=now, + ) + assert score > 0 + + +def test_static_score_or_recent_release_beats_year_old_same_provider(): + now = 1_750_000_000 + fresh = _or_model(model_id="openai/gpt-5", created=now - (60 * 86_400)) + old = _or_model(model_id="openai/gpt-4-turbo", created=now - (700 * 86_400)) + assert static_score_or(fresh, now_ts=now) > static_score_or(old, now_ts=now) + + +# --------------------------------------------------------------------------- +# static_score_yaml +# --------------------------------------------------------------------------- + + +def test_static_score_yaml_includes_operator_bonus(): + cfg = { + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "litellm_params": {"base_model": "azure/gpt-5"}, + } + score = static_score_yaml(cfg) + assert score >= _OPERATOR_TRUST_BONUS + + +def test_static_score_yaml_unknown_provider_still_carries_bonus(): + cfg = { + "provider": "SOME_NEW_PROVIDER", + "model_name": "weird-model", + } + score = static_score_yaml(cfg) + assert score >= _OPERATOR_TRUST_BONUS + + +def test_static_score_yaml_clamped_0_to_100(): + cfg = { + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "litellm_params": {"base_model": "azure/gpt-5"}, + } + assert 0 <= static_score_yaml(cfg) <= 100 + + +# --------------------------------------------------------------------------- +# aggregate_health +# --------------------------------------------------------------------------- + + +def test_aggregate_health_gates_when_uptime_below_threshold(): + """Live data showed Venice-routed cfgs at 53-68%; this guards that the + 90% gate excludes them.""" + venice_endpoints = [ + { + "status": 0, + "uptime_last_30m": 0.55, + "uptime_last_1d": 0.60, + "uptime_last_5m": 0.50, + }, + { + "status": 0, + "uptime_last_30m": 0.65, + "uptime_last_1d": 0.68, + "uptime_last_5m": 0.62, + }, + ] + gated, score = aggregate_health(venice_endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_passes_for_healthy_provider(): + healthy = [ + { + "status": 0, + "uptime_last_30m": 0.99, + "uptime_last_1d": 0.995, + "uptime_last_5m": 0.99, + }, + ] + gated, score = aggregate_health(healthy) + assert gated is False + assert score is not None + assert score >= _HEALTH_GATE_UPTIME_PCT + + +def test_aggregate_health_picks_best_endpoint_across_multiple(): + """Multi-endpoint aggregation should reward the best non-null uptime.""" + mixed = [ + {"status": 0, "uptime_last_30m": 0.55}, + {"status": 0, "uptime_last_30m": 0.97}, # this one passes the gate + ] + gated, score = aggregate_health(mixed) + assert gated is False + assert score is not None + + +def test_aggregate_health_empty_endpoints_gated(): + gated, score = aggregate_health([]) + assert gated is 
True + assert score is None + + +def test_aggregate_health_no_status_zero_gated(): + """Even with high uptime, no OK status means the cfg is broken upstream.""" + endpoints = [ + {"status": 1, "uptime_last_30m": 0.99}, + {"status": 2, "uptime_last_30m": 0.98}, + ] + gated, score = aggregate_health(endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_all_uptime_null_gated(): + endpoints = [ + {"status": 0, "uptime_last_30m": None, "uptime_last_1d": None}, + ] + gated, score = aggregate_health(endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_pct_normalisation(): + """OpenRouter returns 0-1 fractions; some endpoints surface 0-100% + percentages. Both should reach the same gate decision.""" + fraction_form = [{"status": 0, "uptime_last_30m": 0.95}] + pct_form = [{"status": 0, "uptime_last_30m": 95.0}] + g1, s1 = aggregate_health(fraction_form) + g2, s2 = aggregate_health(pct_form) + assert g1 == g2 == False # noqa: E712 + assert s1 is not None and s2 is not None + assert abs(s1 - s2) < 0.5 diff --git a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py new file mode 100644 index 000000000..9e35b6f9c --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py @@ -0,0 +1,157 @@ +"""Unit tests for ``QuotaCheckedVisionLLM``. + +Validates that: + +* Calling ``ainvoke`` routes through ``billable_call`` (premium credit + enforcement) and forwards the inner LLM's response on success. +* The wrapper proxies non-overridden attributes to the inner LLM + (``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output`` + still work without quota gating (they're not used in indexing today). +* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper + bubbles it up — the ETL pipeline catches that and falls back to OCR. 
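+
+A minimal usage sketch under the same assumptions as the fakes below
+(``inner`` is any ChatLiteLLM-like object; constructor kwargs elided)::
+
+    wrapper = QuotaCheckedVisionLLM(inner, user_id=uid, ...)
+    await wrapper.ainvoke(messages)   # gated: runs inside billable_call
+    wrapper.invoke(messages)          # proxied via __getattr__, ungated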
+""" + +from __future__ import annotations + +import contextlib +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +class _FakeInnerLLM: + """Stand-in for ``langchain_litellm.ChatLiteLLM``.""" + + def __init__(self, response: Any = "OCR'd content") -> None: + self._response = response + self.ainvoke_calls: list[Any] = [] + + async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any: + self.ainvoke_calls.append(input) + return self._response + + def some_other_method(self, x: int) -> int: + return x * 2 + + +@contextlib.asynccontextmanager +async def _passthrough_billable_call(**_kwargs): + """Stand-in for billable_call that always allows the call to run.""" + + class _Acc: + total_cost_micros = 0 + total_prompt_tokens = 0 + total_completion_tokens = 0 + grand_total = 0 + calls: list[Any] = [] + + def per_message_summary(self) -> dict[str, dict[str, int]]: + return {} + + yield _Acc() + + +@pytest.mark.asyncio +async def test_ainvoke_routes_through_billable_call(monkeypatch): + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + captured_kwargs: list[dict[str, Any]] = [] + + @contextlib.asynccontextmanager + async def _spy_billable_call(**kwargs): + captured_kwargs.append(kwargs) + async with _passthrough_billable_call() as acc: + yield acc + + monkeypatch.setattr( + "app.services.quota_checked_vision_llm.billable_call", + _spy_billable_call, + raising=False, + ) + + inner = _FakeInnerLLM(response="A red apple on a white table") + user_id = uuid4() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=user_id, + search_space_id=99, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + result = await wrapper.ainvoke([{"text": "what is this?"}]) + assert result == "A red apple on a white table" + assert len(inner.ainvoke_calls) == 1 + assert len(captured_kwargs) == 1 + bc_kwargs = captured_kwargs[0] + assert bc_kwargs["user_id"] == user_id + assert bc_kwargs["search_space_id"] == 99 + assert bc_kwargs["billing_tier"] == "premium" + assert bc_kwargs["base_model"] == "openai/gpt-4o" + assert bc_kwargs["quota_reserve_tokens"] == 4000 + assert bc_kwargs["usage_type"] == "vision_extraction" + + +@pytest.mark.asyncio +async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch): + from app.services.billable_calls import QuotaInsufficientError + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + @contextlib.asynccontextmanager + async def _denying_billable_call(**_kwargs): + raise QuotaInsufficientError( + usage_type="vision_extraction", + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield # unreachable but required for asynccontextmanager type + + monkeypatch.setattr( + "app.services.quota_checked_vision_llm.billable_call", + _denying_billable_call, + raising=False, + ) + + inner = _FakeInnerLLM() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=uuid4(), + search_space_id=1, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + with pytest.raises(QuotaInsufficientError): + await wrapper.ainvoke([{"text": "x"}]) + + # Inner LLM never ran on a denied reservation. + assert inner.ainvoke_calls == [] + + +@pytest.mark.asyncio +async def test_proxies_non_overridden_attributes_to_inner(): + """``__getattr__`` forwards anything not on the proxy itself, so any + method we didn't explicitly override (``invoke``, ``astream``, + ``with_structured_output``, etc.) 
still works — just without quota + gating, which is fine because the indexer only ever calls ainvoke. + """ + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + inner = _FakeInnerLLM() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=uuid4(), + search_space_id=1, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + # ``some_other_method`` is on the inner only. + assert wrapper.some_other_method(7) == 14 diff --git a/surfsense_backend/tests/unit/services/test_supports_image_input.py b/surfsense_backend/tests/unit/services/test_supports_image_input.py new file mode 100644 index 000000000..71fdee1c7 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_supports_image_input.py @@ -0,0 +1,281 @@ +"""Unit tests for the chat-catalog ``supports_image_input`` capability flag. + +Capability is sourced from two places, in order of preference: + +1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs + (authoritative — OpenRouter publishes per-model modalities directly). +2. LiteLLM's authoritative model map (``litellm.supports_vision``) for + YAML / BYOK configs that don't carry an explicit operator override. + +The catalog default is *True* (conservative-allow): an unknown / unmapped +model is not pre-judged. The streaming-task safety net +(``is_known_text_only_chat_model``) is the only place a False actually +blocks a request — and it requires LiteLLM to *explicitly* mark the model +as text-only. +""" + +from __future__ import annotations + +import pytest + +from app.services.openrouter_integration_service import ( + _OPENROUTER_DYNAMIC_MARKER, + _generate_configs, + _supports_image_input, +) + +pytestmark = pytest.mark.unit + + +_SETTINGS_BASE: dict = { + "api_key": "sk-or-test", + "id_offset": -10_000, + "rpm": 200, + "tpm": 1_000_000, + "free_rpm": 20, + "free_tpm": 100_000, + "anonymous_enabled_paid": False, + "anonymous_enabled_free": True, + "quota_reserve_tokens": 4000, +} + + +# --------------------------------------------------------------------------- +# _supports_image_input helper (OpenRouter modalities) +# --------------------------------------------------------------------------- + + +def test_supports_image_input_true_for_multimodal(): + assert ( + _supports_image_input( + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + } + ) + is True + ) + + +def test_supports_image_input_false_for_text_only(): + """The exact failure mode the safety net guards against — DeepSeek V3 + is a text-in/text-out model and would 404 if forwarded image_url.""" + assert ( + _supports_image_input( + { + "id": "deepseek/deepseek-v3.2-exp", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + } + ) + is False + ) + + +def test_supports_image_input_false_when_modalities_missing(): + """Defensive: missing architecture is treated as text-only at the + OpenRouter helper level. 
The wider catalog resolver + (`derive_supports_image_input`) only consults modalities when they + are non-empty, otherwise it falls back to LiteLLM.""" + assert _supports_image_input({"id": "weird/model"}) is False + assert _supports_image_input({"id": "weird/model", "architecture": {}}) is False + assert ( + _supports_image_input( + {"id": "weird/model", "architecture": {"input_modalities": None}} + ) + is False + ) + + +# --------------------------------------------------------------------------- +# _generate_configs threads the flag onto every emitted chat config +# --------------------------------------------------------------------------- + + +def test_generate_configs_emits_supports_image_input(): + raw = [ + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + { + "id": "deepseek/deepseek-v3.2-exp", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.000003", "completion": "0.000015"}, + }, + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + by_model = {c["model_name"]: c for c in cfgs} + + gpt = by_model["openai/gpt-4o"] + assert gpt["supports_image_input"] is True + assert gpt[_OPENROUTER_DYNAMIC_MARKER] is True + + deepseek = by_model["deepseek/deepseek-v3.2-exp"] + assert deepseek["supports_image_input"] is False + assert deepseek[_OPENROUTER_DYNAMIC_MARKER] is True + + +# --------------------------------------------------------------------------- +# YAML loader: defer to derive_supports_image_input on unannotated entries +# --------------------------------------------------------------------------- + + +def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch): + """The regression case: an Azure GPT-5.x YAML entry without a + ``supports_image_input`` override should resolve to True via LiteLLM's + model map (which says ``supports_vision: true``). Previously this + defaulted to False, blocking every image turn for vision-capable + YAML configs.""" + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -2 + name: Azure GPT-4o + provider: AZURE_OPENAI + model_name: gpt-4o + api_key: sk-test +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + assert configs[0]["supports_image_input"] is True + + +def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch): + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -1 + name: GPT-4o + provider: OPENAI + model_name: gpt-4o + api_key: sk-test + supports_image_input: false +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + # Operator override always wins, even against LiteLLM's True. 
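+    # The explicit key wins before the LiteLLM-map lookup (which would
+    # say True for gpt-4o) gets a chance to apply.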
+ assert configs[0]["supports_image_input"] is False + + +def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch): + """Unknown / unmapped model in YAML: default-allow. The streaming + safety net (which requires an explicit-False from LiteLLM) is the + only place a real block happens, so we don't lock the user out of + a freshly added third-party entry the catalog can't introspect.""" + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -1 + name: Some Brand New Model + provider: CUSTOM + custom_provider: brand_new_proxy + model_name: brand-new-model-x9 + api_key: sk-test +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + assert configs[0]["supports_image_input"] is True + + +# --------------------------------------------------------------------------- +# AgentConfig threads the flag through both YAML and Auto / BYOK +# --------------------------------------------------------------------------- + + +def test_agent_config_from_yaml_explicit_overrides_resolver(): + from app.agents.new_chat.llm_config import AgentConfig + + cfg_text_only = AgentConfig.from_yaml_config( + { + "id": -1, + "name": "Text Only Override", + "provider": "openai", + "model_name": "gpt-4o", # Capable per LiteLLM, but operator says no. + "api_key": "sk-test", + "supports_image_input": False, + } + ) + cfg_explicit_vision = AgentConfig.from_yaml_config( + { + "id": -2, + "name": "GPT-4o", + "provider": "openai", + "model_name": "gpt-4o", + "api_key": "sk-test", + "supports_image_input": True, + } + ) + assert cfg_text_only.supports_image_input is False + assert cfg_explicit_vision.supports_image_input is True + + +def test_agent_config_from_yaml_unannotated_uses_resolver(): + """Without an explicit YAML key, AgentConfig defers to the catalog + resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True.""" + from app.agents.new_chat.llm_config import AgentConfig + + cfg = AgentConfig.from_yaml_config( + { + "id": -1, + "name": "GPT-4o (no override)", + "provider": "openai", + "model_name": "gpt-4o", + "api_key": "sk-test", + } + ) + assert cfg.supports_image_input is True + + +def test_agent_config_auto_mode_supports_image_input(): + """Auto routes across the pool. We optimistically allow image input + so users can keep their selection on Auto with a vision-capable + deployment somewhere in the pool. The router's own `allowed_fails` + handles non-vision deployments via fallback.""" + from app.agents.new_chat.llm_config import AgentConfig + + auto = AgentConfig.from_auto_mode() + assert auto.supports_image_input is True diff --git a/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py new file mode 100644 index 000000000..63681828d --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py @@ -0,0 +1,515 @@ +"""Cost-based premium quota unit tests. + +Covers the USD-micro behaviour added in migration 140: + +* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all + calls in a turn — used as the debit amount when ``agent_config.is_premium`` + is true, regardless of which underlying model produced each call. 
This + preserves the prior "premium turn → all calls in turn count" rule from the + token-based system. +* ``estimate_call_reserve_micros`` scales linearly with model pricing, + clamps to a sane floor when pricing is unknown, and respects the + ``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry + can't lock the whole balance on one call. +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# TurnTokenAccumulator — premium-turn debit semantics +# --------------------------------------------------------------------------- + + +def test_total_cost_micros_sums_premium_and_free_calls(): + """A premium turn that also called a free sub-agent debits the union. + + The plan deliberately preserved the existing "premium turn → all calls + count" behaviour because per-call premium filtering relied on + ``LLMRouterService._premium_model_strings`` which only covers router-pool + deployments. ``total_cost_micros`` therefore must include free-model + calls (whose ``cost_micros`` is typically ``0``) as well as the premium + call's actual provider cost. + """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + # Premium model (e.g. claude-opus): non-zero cost. + acc.add( + model="anthropic/claude-3-5-sonnet", + prompt_tokens=1200, + completion_tokens=400, + total_tokens=1600, + cost_micros=12_345, + ) + # Free sub-agent (e.g. title-gen on a free model): zero cost. + acc.add( + model="gpt-4o-mini", + prompt_tokens=120, + completion_tokens=20, + total_tokens=140, + cost_micros=0, + ) + # A second premium-priced call within the same turn. + acc.add( + model="anthropic/claude-3-5-sonnet", + prompt_tokens=800, + completion_tokens=200, + total_tokens=1000, + cost_micros=7_500, + ) + + assert acc.total_cost_micros == 12_345 + 0 + 7_500 + # Token totals stay correct so the FE display path still works. + assert acc.grand_total == 1600 + 140 + 1000 + + +def test_total_cost_micros_zero_when_no_calls(): + """An empty accumulator must report zero cost (no division-by-zero, no None).""" + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + assert acc.total_cost_micros == 0 + assert acc.grand_total == 0 + + +def test_per_message_summary_groups_cost_by_model(): + """``per_message_summary`` must accumulate ``cost_micros`` per model so the + SSE ``model_breakdown`` payload reports actual USD spend per provider. + """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + acc.add( + model="claude-3-5-sonnet", + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + cost_micros=4_000, + ) + acc.add( + model="claude-3-5-sonnet", + prompt_tokens=200, + completion_tokens=100, + total_tokens=300, + cost_micros=8_000, + ) + acc.add( + model="gpt-4o-mini", + prompt_tokens=50, + completion_tokens=10, + total_tokens=60, + cost_micros=200, + ) + + summary = acc.per_message_summary() + assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000 + assert summary["claude-3-5-sonnet"]["total_tokens"] == 450 + assert summary["gpt-4o-mini"]["cost_micros"] == 200 + + +def test_serialized_calls_includes_cost_micros(): + """``serialized_calls`` is what flows into the SSE ``call_details`` + payload; cost_micros must be present on each entry so the FE message-info + dropdown can render per-call USD. 
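+
+    (``cost_micros`` is micro-USD, so the ``42`` below is $0.000042.)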
+ """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + acc.add( + model="m", + prompt_tokens=1, + completion_tokens=1, + total_tokens=2, + cost_micros=42, + ) + serialized = acc.serialized_calls() + assert serialized == [ + { + "model": "m", + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + "cost_micros": 42, + "call_kind": "chat", + } + ] + + +# --------------------------------------------------------------------------- +# estimate_call_reserve_micros — sizing and clamping +# --------------------------------------------------------------------------- + + +def test_reserve_returns_floor_when_model_unknown(monkeypatch): + """If LiteLLM doesn't know the model, ``get_model_info`` raises and the + helper falls back to the 100-micro floor — small enough that a user with + $0.0001 left can still send a tiny request, but non-zero so we still gate + against an empty balance. + """ + import litellm + + from app.services import token_quota_service + + def _raise(_name): + raise KeyError("unknown") + + monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="nonexistent-model", + quota_reserve_tokens=4000, + ) + assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS + assert micros == 100 + + +def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch): + """LiteLLM may *return* a model with both cost-per-token fields at 0 + (pricing not yet registered). The helper must not multiply 0 x tokens + and end up reserving 0 — it must clamp to the floor. + """ + import litellm + + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0}, + raising=False, + ) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="some-pending-model", + quota_reserve_tokens=4000, + ) + assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS + + +def test_reserve_scales_with_model_cost(monkeypatch): + """Claude-Opus-priced model with 4000 reserve_tokens reserves + ~$0.36 = 360_000 micros. Critically this must NOT be clamped down to + some small artificial cap — that was the bug the plan called out. + """ + import litellm + + from app.config import config + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: { + "input_cost_per_token": 15e-6, + "output_cost_per_token": 75e-6, + }, + raising=False, + ) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="claude-3-opus", + quota_reserve_tokens=4000, + ) + # 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros. + assert micros == 360_000 + + +def test_reserve_clamps_to_max_ceiling(monkeypatch): + """A misconfigured "$1000 / M" model with 4000 reserve_tokens would + nominally compute to $4 = 4_000_000 micros. The ceiling + ``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry + can't lock the user's whole balance on one call. 
+ """ + import litellm + + from app.config import config + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: { + "input_cost_per_token": 1e-3, + "output_cost_per_token": 0, + }, + raising=False, + ) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="oops-misconfigured", + quota_reserve_tokens=4000, + ) + assert micros == 1_000_000 + + +def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch): + """Per-config ``quota_reserve_tokens`` is optional; when ``None`` or + zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL`` + so anonymous-style configs still reserve the operator-tunable default. + """ + import litellm + + from app.config import config + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: { + "input_cost_per_token": 1e-6, + "output_cost_per_token": 1e-6, + }, + raising=False, + ) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False) + + # 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros + assert ( + token_quota_service.estimate_call_reserve_micros( + base_model="cheap", quota_reserve_tokens=None + ) + == 4000 + ) + assert ( + token_quota_service.estimate_call_reserve_micros( + base_model="cheap", quota_reserve_tokens=0 + ) + == 4000 + ) + + +# --------------------------------------------------------------------------- +# TokenTrackingCallback — image vs chat usage shape +# --------------------------------------------------------------------------- + + +class _FakeImageUsage: + """Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape).""" + + def __init__( + self, + input_tokens: int = 0, + output_tokens: int = 0, + total_tokens: int | None = None, + ) -> None: + self.input_tokens = input_tokens + self.output_tokens = output_tokens + if total_tokens is not None: + self.total_tokens = total_tokens + + +class _FakeImageResponse: + """Mimics LiteLLM's ``ImageResponse`` — same name so the callback's + ``type(...).__name__`` probe routes to the image branch. + """ + + def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None): + self.usage = usage + if response_cost is not None: + self._hidden_params = {"response_cost": response_cost} + + +# Re-tag the helper class as ``ImageResponse`` for the type-name probe in +# the callback. We can't simply name the class ``ImageResponse`` because +# the test runner sometimes imports test modules in surprising ways and +# we want to be explicit. +_FakeImageResponse.__name__ = "ImageResponse" + + +class _FakeChatUsage: + def __init__(self, prompt: int, completion: int): + self.prompt_tokens = prompt + self.completion_tokens = completion + self.total_tokens = prompt + completion + + +class _FakeChatResponse: + def __init__(self, usage: _FakeChatUsage): + self.usage = usage + + +@pytest.mark.asyncio +async def test_callback_reads_image_usage_input_output_tokens(): + """``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens`` + for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT + prompt_tokens/completion_tokens which is the chat shape. 
+ """ + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + cb = TokenTrackingCallback() + response = _FakeImageResponse( + usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50), + response_cost=0.04, # $0.04 per image + ) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04}, + response_obj=response, + start_time=None, + end_time=None, + ) + assert len(acc.calls) == 1 + call = acc.calls[0] + assert call.prompt_tokens == 42 + assert call.completion_tokens == 8 + assert call.total_tokens == 50 + # 0.04 USD = 40_000 micros + assert call.cost_micros == 40_000 + assert call.call_kind == "image_generation" + + +@pytest.mark.asyncio +async def test_callback_chat_path_unchanged(): + """Chat responses must still read prompt_tokens/completion_tokens.""" + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + cb = TokenTrackingCallback() + response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30)) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={ + "model": "openrouter/anthropic/claude-3-5-sonnet", + "response_cost": 0.0036, + }, + response_obj=response, + start_time=None, + end_time=None, + ) + assert len(acc.calls) == 1 + call = acc.calls[0] + assert call.prompt_tokens == 120 + assert call.completion_tokens == 30 + assert call.total_tokens == 150 + assert call.cost_micros == 3_600 + assert call.call_kind == "chat" + + +@pytest.mark.asyncio +async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch): + """When OpenRouter omits ``usage.cost`` LiteLLM's + ``default_image_cost_calculator`` raises. The defensive image branch in + ``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is + chat-shaped and would raise too) — it returns 0 with a WARNING log. + """ + import litellm + + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + # Force completion_cost to raise the same way OpenRouter image-gen fails. + def _boom(*_args, **_kwargs): + raise ValueError("model_cost: missing entry for openrouter image model") + + monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False) + + # And make sure cost_per_token is NEVER called for the image path — + # if it were, our ``is_image=True`` branch is broken. + cost_per_token_calls: list = [] + + def _record_cost_per_token(**kwargs): + cost_per_token_calls.append(kwargs) + return (0.0, 0.0) + + monkeypatch.setattr( + litellm, "cost_per_token", _record_cost_per_token, raising=False + ) + + cb = TokenTrackingCallback() + response = _FakeImageResponse( + usage=_FakeImageUsage(input_tokens=7, output_tokens=0) + ) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={"model": "openrouter/google/gemini-2.5-flash-image"}, + response_obj=response, + start_time=None, + end_time=None, + ) + + assert len(acc.calls) == 1 + assert acc.calls[0].cost_micros == 0 + assert acc.calls[0].call_kind == "image_generation" + # The image branch must short-circuit before cost_per_token. 
+ assert cost_per_token_calls == [] + + +# --------------------------------------------------------------------------- +# scoped_turn — ContextVar reset semantics (issue B) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_scoped_turn_restores_outer_accumulator(): + """``scoped_turn`` must restore the previous ContextVar value on exit + so a per-call wrapper inside an outer chat turn doesn't leak its + accumulator outward (which would cause double-debit at chat-turn exit). + """ + from app.services.token_tracking_service import ( + get_current_accumulator, + scoped_turn, + start_turn, + ) + + outer = start_turn() + assert get_current_accumulator() is outer + + async with scoped_turn() as inner: + assert get_current_accumulator() is inner + assert inner is not outer + inner.add( + model="x", + prompt_tokens=1, + completion_tokens=1, + total_tokens=2, + cost_micros=5, + ) + + # After exit the outer accumulator is restored unchanged. + assert get_current_accumulator() is outer + assert outer.total_cost_micros == 0 + assert len(outer.calls) == 0 + # The inner accumulator captured the call but didn't bleed into outer. + assert inner.total_cost_micros == 5 + + +@pytest.mark.asyncio +async def test_scoped_turn_resets_to_none_when_no_outer(): + """Running ``scoped_turn`` outside any chat turn (e.g. a background + indexing job) must leave the ContextVar at ``None`` on exit so the + next *unrelated* request starts clean. + """ + from app.services.token_tracking_service import ( + _turn_accumulator, + get_current_accumulator, + scoped_turn, + ) + + # ContextVar default is None for a fresh test isolated context. We + # simulate "no outer" explicitly to be robust against test order. + token = _turn_accumulator.set(None) + try: + assert get_current_accumulator() is None + async with scoped_turn() as acc: + assert get_current_accumulator() is acc + assert get_current_accumulator() is None + finally: + _turn_accumulator.reset(token) diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py new file mode 100644 index 000000000..b8ba9d80c --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py @@ -0,0 +1,89 @@ +"""Defense-in-depth: vision-LLM resolution must not leak ``api_base`` +defaults from ``litellm.api_base`` either. + +Vision shares the same shape as image-gen — global YAML / OpenRouter +dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm`` +call sites would silently drop the empty string and inherit +``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on +construction so we test the kwargs we hand to it instead. 
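+
+Both tests pin the same post-fix outcome: ``api_base=""`` is treated as
+missing and resolves to ``https://openrouter.ai/api/v1``.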
+""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_get_vision_llm_global_openrouter_sets_api_base(): + """Global negative-ID branch: an OpenRouter vision config with + ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with + ``api_base="https://openrouter.ai/api/v1"`` — never an empty string, + never silently absent.""" + from app.services import llm_service + + cfg = { + "id": -30_001, + "name": "GPT-4o Vision (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "", + "api_version": None, + "litellm_params": {}, + "billing_tier": "free", + } + + search_space = MagicMock() + search_space.id = 1 + search_space.user_id = "user-x" + search_space.vision_llm_config_id = cfg["id"] + + session = AsyncMock() + scalars = MagicMock() + scalars.first.return_value = search_space + result = MagicMock() + result.scalars.return_value = scalars + session.execute.return_value = result + + captured: dict = {} + + class FakeSanitized: + def __init__(self, **kwargs): + captured.update(kwargs) + + with ( + patch( + "app.services.vision_llm_router_service.get_global_vision_llm_config", + return_value=cfg, + ), + patch( + "app.agents.new_chat.llm_config.SanitizedChatLiteLLM", + new=FakeSanitized, + ), + ): + await llm_service.get_vision_llm(session=session, search_space_id=1) + + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-4o" + + +def test_vision_router_deployment_sets_api_base_when_config_empty(): + """Auto-mode vision router: deployments are fed to ``litellm.Router``, + so the resolver has to apply at deployment construction time too.""" + from app.services.vision_llm_router_service import VisionLLMRouterService + + deployment = VisionLLMRouterService._config_to_deployment( + { + "model_name": "openai/gpt-4o", + "provider": "OPENROUTER", + "api_key": "sk-or-test", + "api_base": "", + } + ) + assert deployment is not None + assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" + assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o" diff --git a/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py b/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py new file mode 100644 index 000000000..c317eba20 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py @@ -0,0 +1,526 @@ +"""Unit tests for ``AssistantContentBuilder``. + +Pins the in-memory ``ContentPart[]`` projection so the JSONB the server +persists matches what the frontend renders live (see +``surfsense_web/lib/chat/streaming-state.ts``). Every test asserts both +the structural shape of ``snapshot()`` and that the snapshot is +``json.dumps``-safe (the streaming finally block writes it directly to +``new_chat_messages.content`` without an explicit serialization round +trip). 
+""" + +from __future__ import annotations + +import json + +import pytest + +from app.tasks.chat.content_builder import AssistantContentBuilder + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _assert_jsonb_safe(parts: list[dict]) -> None: + """Sanity check: any snapshot must round-trip through ``json.dumps``.""" + serialized = json.dumps(parts) + assert json.loads(serialized) == parts + + +# --------------------------------------------------------------------------- +# Text turns +# --------------------------------------------------------------------------- + + +class TestTextOnly: + def test_single_text_block_collapses_consecutive_deltas(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "Hello") + b.on_text_delta("text-1", " ") + b.on_text_delta("text-1", "world") + b.on_text_end("text-1") + + snap = b.snapshot() + assert snap == [{"type": "text", "text": "Hello world"}] + assert not b.is_empty() + _assert_jsonb_safe(snap) + + def test_empty_text_start_end_pair_leaves_no_part(self): + # Mirrors the FE: a text-start without any deltas should + # not materialise an empty ``{"type":"text","text":""}`` part. + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_end("text-1") + + assert b.snapshot() == [] + assert b.is_empty() + + def test_text_after_text_end_starts_fresh_part(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "first") + b.on_text_end("text-1") + + b.on_text_start("text-2") + b.on_text_delta("text-2", "second") + b.on_text_end("text-2") + + snap = b.snapshot() + assert snap == [ + {"type": "text", "text": "first"}, + {"type": "text", "text": "second"}, + ] + + +class TestReasoningThenText: + def test_reasoning_followed_by_text_yields_two_parts_in_order(self): + b = AssistantContentBuilder() + b.on_reasoning_start("r-1") + b.on_reasoning_delta("r-1", "Considering options...") + b.on_reasoning_end("r-1") + + b.on_text_start("text-1") + b.on_text_delta("text-1", "The answer is 42.") + b.on_text_end("text-1") + + snap = b.snapshot() + assert snap == [ + {"type": "reasoning", "text": "Considering options..."}, + {"type": "text", "text": "The answer is 42."}, + ] + _assert_jsonb_safe(snap) + + def test_text_delta_after_reasoning_implicitly_closes_reasoning(self): + # Mirrors FE ``appendText``: a text delta arriving while a + # reasoning part is "active" still produces a fresh text + # part, never appends into the reasoning block. + b = AssistantContentBuilder() + b.on_reasoning_start("r-1") + b.on_reasoning_delta("r-1", "thinking") + # No explicit reasoning_end — text delta should close it. + b.on_text_delta("text-1", "answer") + + snap = b.snapshot() + assert snap == [ + {"type": "reasoning", "text": "thinking"}, + {"type": "text", "text": "answer"}, + ] + + +# --------------------------------------------------------------------------- +# Tool calls +# --------------------------------------------------------------------------- + + +class TestToolHeavyTurn: + def test_full_tool_lifecycle_produces_complete_tool_call_part(self): + b = AssistantContentBuilder() + # Some narration before the tool fires. 
+ b.on_text_start("text-1") + b.on_text_delta("text-1", "Searching...") + b.on_text_end("text-1") + + b.on_tool_input_start( + ui_id="call_run123", + tool_name="web_search", + langchain_tool_call_id="lc_tool_abc", + ) + b.on_tool_input_delta("call_run123", '{"query":') + b.on_tool_input_delta("call_run123", '"surfsense"}') + b.on_tool_input_available( + ui_id="call_run123", + tool_name="web_search", + args={"query": "surfsense"}, + langchain_tool_call_id="lc_tool_abc", + ) + b.on_tool_output_available( + ui_id="call_run123", + output={"status": "completed", "citations": {}}, + langchain_tool_call_id="lc_tool_abc", + ) + + snap = b.snapshot() + assert snap[0] == {"type": "text", "text": "Searching..."} + tool_part = snap[1] + assert tool_part["type"] == "tool-call" + assert tool_part["toolCallId"] == "call_run123" + assert tool_part["toolName"] == "web_search" + assert tool_part["args"] == {"query": "surfsense"} + # ``argsText`` is the pretty-printed final JSON, not the raw + # streaming buffer (FE ``stream-pipeline.ts:128``). + assert tool_part["argsText"] == json.dumps( + {"query": "surfsense"}, indent=2, ensure_ascii=False + ) + assert tool_part["langchainToolCallId"] == "lc_tool_abc" + assert tool_part["result"] == {"status": "completed", "citations": {}} + _assert_jsonb_safe(snap) + + def test_tool_input_available_without_prior_start_creates_card(self): + # Legacy / parity_v2-OFF path: tool-input-available may be + # emitted without a prior tool-input-start (no streamed + # tool_call_chunks). The card should still be created. + b = AssistantContentBuilder() + b.on_tool_input_available( + ui_id="call_run42", + tool_name="grep", + args={"pattern": "TODO"}, + langchain_tool_call_id="lc_x", + ) + b.on_tool_output_available( + ui_id="call_run42", + output={"matches": 3}, + langchain_tool_call_id="lc_x", + ) + + snap = b.snapshot() + assert len(snap) == 1 + part = snap[0] + assert part["type"] == "tool-call" + assert part["toolCallId"] == "call_run42" + assert part["args"] == {"pattern": "TODO"} + assert part["langchainToolCallId"] == "lc_x" + assert part["result"] == {"matches": 3} + + def test_tool_input_start_idempotent_for_same_ui_id(self): + # parity_v2: tool-input-start can fire from BOTH the chunk + # registration path AND the canonical ``on_tool_start`` path. + # The second call must not create a duplicate part. + b = AssistantContentBuilder() + b.on_tool_input_start("call_x", "ls", "lc_x") + b.on_tool_input_start("call_x", "ls", "lc_x") + snap = b.snapshot() + assert len(snap) == 1 + + def test_tool_input_delta_without_prior_start_is_silently_dropped(self): + b = AssistantContentBuilder() + b.on_tool_input_delta("call_unknown", '{"orphan": "delta"}') + assert b.snapshot() == [] + + def test_langchain_tool_call_id_backfills_only_when_absent(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_x", "ls", "lc_first") + # Late event must NOT clobber an already-set lc id. + b.on_tool_input_start("call_x", "ls", "lc_late") + snap = b.snapshot() + assert snap[0]["langchainToolCallId"] == "lc_first" + + def test_args_text_streaming_buffer_reflects_concatenation(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_x", "save_doc", "lc_y") + b.on_tool_input_delta("call_x", '{"title":') + b.on_tool_input_delta("call_x", '"Hi"}') + # Snapshot mid-stream should see the partial buffer (the FE + # tolerates invalid JSON and renders it as-is). 
+ mid = b.snapshot() + assert mid[0]["argsText"] == '{"title":"Hi"}' + # Then tool-input-available replaces with pretty-printed. + b.on_tool_input_available( + "call_x", + "save_doc", + {"title": "Hi"}, + "lc_y", + ) + final = b.snapshot() + assert final[0]["argsText"] == json.dumps( + {"title": "Hi"}, indent=2, ensure_ascii=False + ) + + +# --------------------------------------------------------------------------- +# Thinking steps & separators +# --------------------------------------------------------------------------- + + +class TestThinkingSteps: + def test_first_thinking_step_unshifts_singleton_to_index_zero(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "Hello") + b.on_text_end("text-1") + + b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item-a"]) + + snap = b.snapshot() + # Singleton goes to index 0 (FE ``updateThinkingSteps`` unshift). + assert snap[0]["type"] == "data-thinking-steps" + assert snap[0]["data"]["steps"] == [ + { + "id": "step-1", + "title": "Analyzing", + "status": "in_progress", + "items": ["item-a"], + } + ] + assert snap[1] == {"type": "text", "text": "Hello"} + + def test_subsequent_thinking_steps_mutate_the_singleton_in_place(self): + b = AssistantContentBuilder() + b.on_thinking_step("step-1", "Analyzing", "in_progress", []) + b.on_thinking_step("step-2", "Searching", "in_progress", ["q"]) + b.on_thinking_step("step-1", "Analyzing", "completed", ["done"]) + + snap = b.snapshot() + assert len([p for p in snap if p["type"] == "data-thinking-steps"]) == 1 + steps = snap[0]["data"]["steps"] + assert len(steps) == 2 + assert steps[0]["id"] == "step-1" + assert steps[0]["status"] == "completed" + assert steps[0]["items"] == ["done"] + assert steps[1]["id"] == "step-2" + + def test_thinking_step_with_text_continues_appending_to_text(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "first") + + # Thinking step inserts at index 0, bumps text idx from 0 to 1. + b.on_thinking_step("step-1", "Working", "in_progress", []) + b.on_text_delta("text-1", " second") + + snap = b.snapshot() + text_parts = [p for p in snap if p["type"] == "text"] + assert text_parts == [{"type": "text", "text": "first second"}] + + def test_thinking_step_without_id_is_dropped(self): + b = AssistantContentBuilder() + b.on_thinking_step("", "noop", "in_progress", None) + assert b.snapshot() == [] + assert b.is_empty() + + +class TestStepSeparators: + def test_separator_no_op_before_any_content(self): + b = AssistantContentBuilder() + b.on_step_separator() + assert b.snapshot() == [] + + def test_separator_after_text_appends_with_step_index_zero(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "first") + b.on_text_end("text-1") + + b.on_step_separator() + + snap = b.snapshot() + assert snap[-1] == { + "type": "data-step-separator", + "data": {"stepIndex": 0}, + } + + def test_consecutive_separators_collapse_to_one(self): + b = AssistantContentBuilder() + b.on_text_delta("text-1", "x") + b.on_step_separator() + b.on_step_separator() # No-op: previous part is already a separator. 
+ snap = b.snapshot() + assert sum(1 for p in snap if p["type"] == "data-step-separator") == 1 + + def test_step_index_increments_across_separators(self): + b = AssistantContentBuilder() + b.on_text_delta("text-1", "a") + b.on_step_separator() + b.on_text_delta("text-2", "b") + b.on_step_separator() + snap = b.snapshot() + seps = [p for p in snap if p["type"] == "data-step-separator"] + assert [s["data"]["stepIndex"] for s in seps] == [0, 1] + + +# --------------------------------------------------------------------------- +# Interruption handling +# --------------------------------------------------------------------------- + + +class TestMarkInterrupted: + def test_running_tool_calls_get_state_aborted(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_a", "ls", "lc_a") + b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a") + # No tool-output-available — simulates client disconnect mid-tool. + + b.mark_interrupted() + + snap = b.snapshot() + assert snap[0]["state"] == "aborted" + assert "result" not in snap[0] + + def test_completed_tool_calls_are_not_marked_aborted(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_a", "ls", "lc_a") + b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a") + b.on_tool_output_available("call_a", {"files": []}, "lc_a") + + b.mark_interrupted() + + snap = b.snapshot() + assert "state" not in snap[0] + assert snap[0]["result"] == {"files": []} + + def test_open_text_block_keeps_accumulated_content(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "partial") + # No on_text_end — disconnect mid-stream. + + b.mark_interrupted() + + snap = b.snapshot() + assert snap == [{"type": "text", "text": "partial"}] + + +# --------------------------------------------------------------------------- +# is_empty / snapshot semantics +# --------------------------------------------------------------------------- + + +class TestIsEmpty: + def test_fresh_builder_is_empty(self): + assert AssistantContentBuilder().is_empty() + + def test_text_part_breaks_emptiness(self): + b = AssistantContentBuilder() + b.on_text_delta("text-1", "x") + assert not b.is_empty() + + def test_tool_call_breaks_emptiness(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_x", "ls", None) + assert not b.is_empty() + + def test_thinking_step_alone_does_not_break_emptiness(self): + # Mirrors the "status marker fallback" semantic: a turn that + # only emitted a thinking step before being interrupted should + # still be treated as empty for finalize_assistant_turn's + # status-marker substitution. + b = AssistantContentBuilder() + b.on_thinking_step("step-1", "Working", "in_progress", []) + assert b.is_empty() + + def test_step_separator_alone_does_not_break_emptiness(self): + b = AssistantContentBuilder() + # Force a separator (it would normally no-op without content, + # but we simulate the underlying state to verify is_empty is + # not fooled by a stray separator). + b.parts.append({"type": "data-step-separator", "data": {"stepIndex": 0}}) + assert b.is_empty() + + +class TestSnapshotSemantics: + def test_snapshot_is_deep_copied_so_mutations_do_not_leak(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_x", "ls", "lc_x") + b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x") + snap = b.snapshot() + # Mutate the returned snapshot — original should be untouched. 
+ snap[0]["args"]["mutated"] = True + snap[0]["state"] = "tampered" + + again = b.snapshot() + assert "mutated" not in again[0]["args"] + assert "state" not in again[0] + + def test_snapshot_round_trips_through_json(self): + b = AssistantContentBuilder() + b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item"]) + b.on_text_delta("text-1", "answer") + b.on_tool_input_start("call_x", "ls", "lc_x") + b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x") + b.on_tool_output_available("call_x", {"files": ["a.txt"]}, "lc_x") + b.on_step_separator() + snap = b.snapshot() + + encoded = json.dumps(snap) + assert json.loads(encoded) == snap + + +class TestStats: + """``stats()`` is the perf-log handle for [PERF] [stream_*] + finalize_payload lines. Pin the schema so an ops dashboard can + rely on these keys being present and meaningful. + """ + + def test_fresh_builder_reports_all_zeros(self): + b = AssistantContentBuilder() + s = b.stats() + assert s == { + "parts": 0, + "bytes": 2, # ``[]`` is two bytes + "text": 0, + "reasoning": 0, + "tool_calls": 0, + "tool_calls_completed": 0, + "tool_calls_aborted": 0, + "thinking_step_parts": 0, + "step_separators": 0, + } + + def test_counts_each_part_type_independently(self): + b = AssistantContentBuilder() + b.on_text_start("t1") + b.on_text_delta("t1", "hi") + b.on_text_end("t1") + b.on_reasoning_start("r1") + b.on_reasoning_delta("r1", "thinking") + b.on_reasoning_end("r1") + b.on_thinking_step("step-1", "Analyzing", "completed", ["item"]) + b.on_step_separator() + b.on_tool_input_start("call_done", "ls", "lc_done") + b.on_tool_input_available("call_done", "ls", {}, "lc_done") + b.on_tool_output_available("call_done", {"ok": True}, "lc_done") + b.on_tool_input_start("call_running", "rm", "lc_running") + b.on_tool_input_available("call_running", "rm", {}, "lc_running") + + s = b.stats() + assert s["text"] == 1 + assert s["reasoning"] == 1 + assert s["tool_calls"] == 2 + assert s["tool_calls_completed"] == 1 + assert s["tool_calls_aborted"] == 0 + assert s["thinking_step_parts"] == 1 + assert s["step_separators"] == 1 + assert s["parts"] == sum( + [ + s["text"], + s["reasoning"], + s["tool_calls"], + s["thinking_step_parts"], + s["step_separators"], + ] + ) + assert s["bytes"] > 0 + + def test_mark_interrupted_flips_running_calls_to_aborted_in_stats(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_done", "ls", "lc_done") + b.on_tool_input_available("call_done", "ls", {}, "lc_done") + b.on_tool_output_available("call_done", {"ok": True}, "lc_done") + b.on_tool_input_start("call_running", "rm", "lc_running") + b.on_tool_input_available("call_running", "rm", {}, "lc_running") + + # Pre-interrupt: one completed, one still running (no result). + pre = b.stats() + assert pre["tool_calls_completed"] == 1 + assert pre["tool_calls_aborted"] == 0 + + b.mark_interrupted() + post = b.stats() + assert post["tool_calls_completed"] == 1 + assert post["tool_calls_aborted"] == 1 + assert post["tool_calls"] == 2 + + def test_bytes_reflects_jsonb_payload_size(self): + # Each text-delta adds bytes monotonically — useful for catching + # an unbounded delta buffer regression in the perf signal. 
+ b = AssistantContentBuilder() + b.on_text_start("t1") + b.on_text_delta("t1", "x" * 10) + small = b.stats()["bytes"] + b.on_text_delta("t1", "x" * 1000) + large = b.stats()["bytes"] + assert large > small + 900 diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py index 9258d5cfe..60750396c 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py @@ -51,22 +51,34 @@ class _FakeToolMessage: tool_call_id: str | None = None +@dataclass +class _FakeInterrupt: + value: dict[str, Any] + + +@dataclass +class _FakeTask: + interrupts: tuple[_FakeInterrupt, ...] = () + + class _FakeAgentState: """Stand-in for ``StateSnapshot`` returned by ``aget_state``.""" - def __init__(self) -> None: + def __init__(self, tasks: list[Any] | None = None) -> None: # Empty values keeps the cloud-fallback safety-net branch a no-op, - # and an empty ``tasks`` list keeps the post-stream interrupt - # check a no-op too. + # and empty ``tasks`` keep the post-stream interrupt check a no-op too. self.values: dict[str, Any] = {} - self.tasks: list[Any] = [] + self.tasks: list[Any] = tasks or [] class _FakeAgent: """Replays a list of ``astream_events`` events.""" - def __init__(self, events: list[dict[str, Any]]) -> None: + def __init__( + self, events: list[dict[str, Any]], state: _FakeAgentState | None = None + ) -> None: self._events = events + self._state = state or _FakeAgentState() async def astream_events( # type: ignore[no-untyped-def] self, _input_data: Any, *, config: dict[str, Any], version: str @@ -79,7 +91,7 @@ class _FakeAgent: # Called once after astream_events drains so the cloud-fallback # safety net can inspect staged filesystem work. The fake stays # empty so the safety net is a no-op. - return _FakeAgentState() + return self._state def _model_stream( @@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None: ) -async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]: +async def _drain( + events: list[dict[str, Any]], state: _FakeAgentState | None = None +) -> list[dict[str, Any]]: """Run ``_stream_agent_events`` against a fake agent and return the SSE payloads (parsed JSON) it yielded. 
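+
+    Passing ``state`` seeds the ``aget_state`` snapshot that the
+    post-stream interrupt check inspects, e.g. (as the interrupt test
+    at the bottom of this file does)::
+
+        payloads = await _drain([], state=_FakeAgentState(tasks=[...]))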
""" - agent = _FakeAgent(events) + agent = _FakeAgent(events, state=state) service = VercelStreamingService() result = StreamResult() config = {"configurable": {"thread_id": "test-thread"}} @@ -525,3 +539,31 @@ async def test_unmatched_fallback_still_attaches_lc_id( assert len(starts) == 1 assert starts[0]["toolCallId"].startswith("call_run-1") assert starts[0]["langchainToolCallId"] == "lc-orphan" + + +@pytest.mark.asyncio +async def test_interrupt_request_uses_task_that_contains_interrupt( + parity_v2_on: None, +) -> None: + interrupt_payload = { + "type": "calendar_event_create", + "action": { + "tool": "create_calendar_event", + "params": {"summary": "mom bday"}, + }, + "context": {}, + } + state = _FakeAgentState( + tasks=[ + _FakeTask(interrupts=()), + _FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)), + ] + ) + + payloads = await _drain([], state=state) + + interrupts = _of_type(payloads, "data-interrupt-request") + assert len(interrupts) == 1 + assert ( + interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event" + ) diff --git a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py new file mode 100644 index 000000000..a5bb3f58a --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py @@ -0,0 +1,318 @@ +"""Regression tests for ``run_async_celery_task``. + +These tests pin down the production bug observed on 2026-05-02 where +the video-presentation Celery task hung at ``[billable_call] finalize`` +because the shared ``app.db.engine`` had pooled asyncpg connections +bound to a *previous* task's now-closed event loop. Reusing such a +connection on a fresh loop crashes inside ``pool_pre_ping`` with:: + + AttributeError: 'NoneType' object has no attribute 'send' + +(the proactor is None because the loop is gone) and can hang forever +inside the asyncpg ``Connection._cancel`` cleanup coroutine. + +The fix is ``run_async_celery_task``: a small helper that runs every +async celery task body inside a fresh event loop and disposes the +shared engine pool both before (defends against a previous task that +crashed) and after (releases connections we opened on this loop). + +Tests here exercise the helper with a stub engine that records +``dispose()`` calls and panics if a coroutine produced by one loop is +awaited on another — mirroring the real asyncpg behaviour. +""" + +from __future__ import annotations + +import asyncio +import gc +import sys +from collections.abc import Iterator +from contextlib import contextmanager +from unittest.mock import patch + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Stub engine that emulates the asyncpg-on-stale-loop crash +# --------------------------------------------------------------------------- + + +class _StaleLoopEngine: + """Tiny stand-in for ``app.db.engine`` that tracks dispose() calls. + + ``dispose()`` is async (matches ``AsyncEngine.dispose``) and records + the running event loop id so tests can assert it ran on *each* + fresh loop. + """ + + def __init__(self) -> None: + self.dispose_loop_ids: list[int] = [] + + async def dispose(self) -> None: + loop = asyncio.get_running_loop() + self.dispose_loop_ids.append(id(loop)) + + +@contextmanager +def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]: + """Patch ``from app.db import engine as shared_engine`` lookup. 
+
+    ``run_async_celery_task`` imports the engine lazily inside its own
+    body, so we have to patch the attribute on the already-loaded
+    ``app.db`` module.
+    """
+    import app.db as app_db
+
+    original = getattr(app_db, "engine", None)
+    app_db.engine = stub  # type: ignore[attr-defined]
+    try:
+        yield
+    finally:
+        if original is None:
+            # ``app.db`` had no ``engine`` attribute before we patched
+            # it, so restoring means deleting the stub we installed.
+            # (The stub is still bound here, so asserting that access
+            # raises AttributeError would itself fail.)
+            delattr(app_db, "engine")
+        else:
+            app_db.engine = original  # type: ignore[attr-defined]
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+def test_runner_returns_value_and_disposes_engine_around_call() -> None:
+    """Happy path: the coroutine result is returned, and the shared
+    engine is disposed both before and after the task body runs.
+    """
+    from app.tasks.celery_tasks import run_async_celery_task
+
+    stub = _StaleLoopEngine()
+
+    async def _body() -> str:
+        # Engine should already have been disposed once before we run.
+        assert len(stub.dispose_loop_ids) == 1
+        return "ok"
+
+    with _patch_shared_engine(stub):
+        result = run_async_celery_task(_body)
+
+    assert result == "ok"
+    # Once before the body, once after (in finally).
+    assert len(stub.dispose_loop_ids) == 2
+    # Both disposes ran on the SAME (fresh) loop the task body used.
+    assert stub.dispose_loop_ids[0] == stub.dispose_loop_ids[1]
+
+
+def test_runner_creates_fresh_loop_per_invocation() -> None:
+    """Each call must spin its own loop. Without this guarantee a
+    previous task's loop would be reused and the asyncpg-stale-loop
+    crash would never be avoided.
+    """
+    import app.tasks.celery_tasks as celery_tasks_pkg
+
+    stub = _StaleLoopEngine()
+    new_loop_calls = 0
+    closed_loops: list[bool] = []
+
+    real_new_event_loop = asyncio.new_event_loop
+
+    def _counting_new_loop() -> asyncio.AbstractEventLoop:
+        nonlocal new_loop_calls
+        new_loop_calls += 1
+        loop = real_new_event_loop()
+        # Hook close() so we can verify each loop was closed properly
+        # before the next one was created.
+        original_close = loop.close
+
+        def _tracked_close() -> None:
+            closed_loops.append(True)
+            original_close()
+
+        loop.close = _tracked_close  # type: ignore[method-assign]
+        return loop
+
+    async def _body() -> None:
+        # Loop is alive and current at body execution time.
+        running = asyncio.get_running_loop()
+        assert not running.is_closed()
+
+    with (
+        _patch_shared_engine(stub),
+        patch.object(asyncio, "new_event_loop", _counting_new_loop),
+    ):
+        for _ in range(3):
+            celery_tasks_pkg.run_async_celery_task(_body)
+
+    assert new_loop_calls == 3
+    assert closed_loops == [True, True, True]
+    # Each invocation disposed twice (before + after).
+    assert len(stub.dispose_loop_ids) == 6
+
+
+def test_runner_disposes_engine_even_when_body_raises() -> None:
+    """Cleanup MUST run on the failure path too — otherwise stale
+    connections leak into the next task and cause the original hang.
+    """
+    from app.tasks.celery_tasks import run_async_celery_task
+
+    stub = _StaleLoopEngine()
+
+    class _BoomError(RuntimeError):
+        pass
+
+    async def _body() -> None:
+        raise _BoomError("kaboom")
+
+    with _patch_shared_engine(stub), pytest.raises(_BoomError):
+        run_async_celery_task(_body)
+
+    assert len(stub.dispose_loop_ids) == 2  # before + after still ran
+
+
+def test_runner_swallows_dispose_errors() -> None:
+    """A flaky engine.dispose() must NEVER take down a celery task.
+ + Production scenario: the very first dispose (before the body runs) + might hit a partially-initialised engine; the helper logs and + moves on. The task body still runs; the result is still returned. + """ + from app.tasks.celery_tasks import run_async_celery_task + + class _AngryEngine: + def __init__(self) -> None: + self.calls = 0 + + async def dispose(self) -> None: + self.calls += 1 + raise RuntimeError("dispose() blew up") + + stub = _AngryEngine() + + async def _body() -> int: + return 42 + + with _patch_shared_engine(stub): + assert run_async_celery_task(_body) == 42 + + assert stub.calls == 2 # before + after both attempted + + +def test_runner_propagates_value_from_async_body() -> None: + """Sanity: pass-through of any pickleable celery return value.""" + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + async def _body() -> dict[str, object]: + return {"status": "ready", "video_presentation_id": 19} + + with _patch_shared_engine(stub): + out = run_async_celery_task(_body) + + assert out == {"status": "ready", "video_presentation_id": 19} + + +def test_video_presentation_task_uses_runner_helper() -> None: + """Defence-in-depth: confirm the celery task module imports + ``run_async_celery_task``. If a future refactor inlines a + ``loop = asyncio.new_event_loop(); ... loop.close()`` block again, + the original hang will return. + """ + # The module's task body should not contain a manual new_event_loop + # call — that's exactly what the helper exists to centralise. + import inspect + + from app.tasks.celery_tasks import video_presentation_tasks + + src = inspect.getsource(video_presentation_tasks) + assert "run_async_celery_task" in src, ( + "video_presentation_tasks.py must use run_async_celery_task; " + "manual asyncio.new_event_loop() in a celery task hangs on the " + "shared SQLAlchemy pool when reused across tasks." + ) + assert "asyncio.new_event_loop" not in src, ( + "video_presentation_tasks.py contains a raw asyncio.new_event_loop " + "call — route every async task through run_async_celery_task to " + "avoid the stale-pool hang." + ) + + +def test_podcast_task_uses_runner_helper() -> None: + """Symmetric assertion for the podcast task — same root cause, same + fix, same regression risk. + """ + import inspect + + from app.tasks.celery_tasks import podcast_tasks + + src = inspect.getsource(podcast_tasks) + assert "run_async_celery_task" in src + assert "asyncio.new_event_loop" not in src + + +def test_runner_runs_shutdown_asyncgens_before_close() -> None: + """If the task body created any async generators that didn't get + fully iterated, we must still call ``loop.shutdown_asyncgens()`` + before closing — otherwise we leak event-loop bound resources + that re-emerge as ``RuntimeError: Event loop is closed`` later. + """ + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + async def _agen(): + try: + yield 1 + yield 2 + finally: + pass + + async def _body() -> None: + # Iterate the agen partially, then leave it dangling — exactly + # the situation shutdown_asyncgens() is designed to clean up. + async for v in _agen(): + if v == 1: + break + + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + # By the time the helper returns, garbage collection + shutdown_asyncgens + # should have ensured no live async-gen references remain. We don't + # assert agen.closed directly (it depends on GC ordering); the real + # contract is "no warnings, no event-loop-closed errors". 
A successful + # second invocation proves the loop was cleaned up properly. + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + # Force a GC pass to surface any 'coroutine was never awaited' + # warnings that would indicate the cleanup is broken. + gc.collect() + + +def test_runner_uses_proactor_loop_on_windows() -> None: + """On Windows the celery worker preselects a Proactor policy so + subprocess (ffmpeg) calls work. The helper must not silently fall + back to a Selector loop and re-break video/podcast generation. + """ + if not sys.platform.startswith("win"): + pytest.skip("Windows-specific event-loop policy assertion") + + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + # Mirror the policy set at the top of every Windows celery task. + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + + observed: list[str] = [] + + async def _body() -> None: + observed.append(type(asyncio.get_running_loop()).__name__) + + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + assert observed == ["ProactorEventLoop"] diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py new file mode 100644 index 000000000..699297df1 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py @@ -0,0 +1,388 @@ +"""Unit tests for podcast Celery task billing integration. + +Validates ``_generate_content_podcast`` correctly wraps +``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the +search-space owner's billing decision, and degrades cleanly when the +resolver fails or premium credit is exhausted. + +Coverage: + +* Happy-path free config: resolver → ``billable_call`` enters with + ``usage_type='podcast_generation'`` and the configured reserve override, + graph runs, podcast row flips to ``READY``. +* Happy-path premium config: same wiring with ``billing_tier='premium'``. +* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` → + graph is *not* invoked, podcast row flips to ``FAILED``, return dict + carries ``reason='premium_quota_exhausted'``. +* Resolver failure: ``ValueError`` from the resolver → podcast row flips + to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``. +""" + +from __future__ import annotations + +import contextlib +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + def filter(self, *_args, **_kwargs): + return self + + +class _FakeSession: + def __init__(self, podcast): + self._podcast = podcast + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self._podcast) + + async def commit(self): + self.commit_count += 1 + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return None + + +class _FakeSessionMaker: + def __init__(self, session: _FakeSession): + self._session = session + + def __call__(self): + return self._session + + +def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace: + """Stand-in for a ``Podcast`` row. 
Importing ``PodcastStatus`` lazily + inside helpers keeps this fixture cheap.""" + return SimpleNamespace( + id=podcast_id, + title="Test Podcast", + thread_id=thread_id, + status=None, + podcast_transcript=None, + file_location=None, + ) + + +@contextlib.asynccontextmanager +async def _ok_billable_call(**kwargs): + """Stand-in for ``billable_call`` that records its kwargs and yields a + no-op accumulator-shaped object.""" + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + + +_CALL_LOG: list[dict[str, Any]] = [] + + +@contextlib.asynccontextmanager +async def _denying_billable_call(**kwargs): + from app.services.billable_calls import QuotaInsufficientError + + _CALL_LOG.append(kwargs) + raise QuotaInsufficientError( + usage_type=kwargs.get("usage_type", "?"), + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield SimpleNamespace() # pragma: no cover — for grammar only + + +@contextlib.asynccontextmanager +async def _settlement_failing_billable_call(**kwargs): + from app.services.billable_calls import BillingSettlementError + + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + raise BillingSettlementError( + usage_type=kwargs.get("usage_type", "?"), + user_id=kwargs["user_id"], + cause=RuntimeError("finalize failed"), + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_call_log(): + _CALL_LOG.clear() + yield + _CALL_LOG.clear() + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch): + """Happy path: free billing tier still wraps the graph call so the + audit row is recorded. Verifies kwargs threading.""" + from app.config import config as app_config + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=7, thread_id=99) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + assert search_space_id == 555 + assert thread_id == 99 + return user_id, "free", "openrouter/some-free-model" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return { + "podcast_transcript": [ + SimpleNamespace(speaker_id=0, dialog="Hi"), + SimpleNamespace(speaker_id=1, dialog="Hello"), + ], + "final_podcast_file_path": "/tmp/podcast.wav", + } + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hello world", + search_space_id=555, + user_prompt="make it short", + ) + + assert result["status"] == "ready" + assert result["podcast_id"] == 7 + assert podcast.status == PodcastStatus.READY + assert podcast.file_location == "/tmp/podcast.wav" + + assert len(_CALL_LOG) == 1 + call = _CALL_LOG[0] + assert call["user_id"] == user_id + assert call["search_space_id"] == 555 + assert call["billing_tier"] == "free" + assert call["base_model"] == "openrouter/some-free-model" + assert call["usage_type"] == "podcast_generation" + assert ( + call["quota_reserve_micros_override"] + == 
app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS + ) + # Background artifact audit rows intentionally omit the TokenUsage.thread_id + # FK to avoid coupling Celery audit commits to an active chat transaction. + assert "thread_id" not in call + assert call["call_details"] == { + "podcast_id": 7, + "title": "Test Podcast", + "thread_id": 99, + } + assert callable(call["billable_session_factory"]) + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_premium_tier(monkeypatch): + """Premium resolution flows through to ``billable_call`` so the + reserve/finalize path triggers.""" + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast() + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return user_id, "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert _CALL_LOG[0]["billing_tier"] == "premium" + assert _CALL_LOG[0]["base_model"] == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch): + """When ``billable_call`` denies the reservation, the graph never + runs and the podcast row flips to FAILED with the documented reason + code.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=8) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=8, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 8, + "reason": "premium_quota_exhausted", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] # Graph never ran on denied reservation. 
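+
+
+# The reason codes asserted in these tests map onto the task's except
+# arms. A sketch of the assumed control flow (``_mark_failed`` is a
+# hypothetical name; the real body lives in
+# ``podcast_tasks._generate_content_podcast``):
+#
+#     try:
+#         async with billable_call(...):
+#             out = await podcaster_graph.ainvoke(state, config)
+#     except QuotaInsufficientError:
+#         return await _mark_failed(podcast, "premium_quota_exhausted")
+#     except BillingSettlementError:
+#         return await _mark_failed(podcast, "billing_settlement_failed")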
+ + +@pytest.mark.asyncio +async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch): + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=10) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr( + podcast_tasks, "billable_call", _settlement_failing_billable_call + ) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=10, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 10, + "reason": "billing_settlement_failed", + } + assert podcast.status == PodcastStatus.FAILED + + +@pytest.mark.asyncio +async def test_resolver_failure_marks_podcast_failed(monkeypatch): + """If the resolver raises (e.g. search-space deleted), the task fails + cleanly without invoking the graph.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=9) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _failing_resolver(sess, search_space_id, *, thread_id=None): + raise ValueError("Search space 555 not found") + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=9, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 9, + "reason": "billing_resolution_failed", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] diff --git a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py new file mode 100644 index 000000000..792d059b0 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py @@ -0,0 +1,119 @@ +"""Predicate-level test for the chat streaming safety net. + +The safety net in ``stream_new_chat`` rejects an image turn early with +a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the +selected model is *known* to be text-only. The earlier round of this +work used a strict opt-in flag (``supports_image_input`` defaulting to +False on every YAML entry) which blocked vision-capable Azure GPT-5.x +deployments — this is the regression we're fixing. + +The new predicate is :func:`is_known_text_only_chat_model`, which +returns True only when LiteLLM's authoritative model map *explicitly* +sets ``supports_vision=False``. Anything else (vision True, missing +key, exception) returns False so the request flows through to the +provider. 
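+
+In table form, the only row that blocks is the explicit False::
+
+    litellm.get_model_info result      predicate
+    supports_vision is False       ->  True   (reject early)
+    supports_vision is True        ->  False  (flow through)
+    key missing / lookup raises    ->  False  (flow through)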
+ +We exercise the predicate directly here rather than driving the full +``stream_new_chat`` generator — covering the gate in isolation keeps +the test focused on the regression while the generator's wider behavior +is exercised by the integration suite. +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_capabilities import is_known_text_only_chat_model + +pytestmark = pytest.mark.unit + + +def test_safety_net_does_not_fire_for_azure_gpt_4o(): + """Regression: ``azure/gpt-4o`` (and the GPT-5.x variants) is + vision-capable per LiteLLM's model map. The previous round's + blanket-False default blocked it; the new predicate must NOT mark + it text-only.""" + assert ( + is_known_text_only_chat_model( + provider="AZURE_OPENAI", + model_name="my-azure-deployment", + base_model="gpt-4o", + ) + is False + ) + + +def test_safety_net_does_not_fire_for_unknown_model(): + """Default-pass on unknown — the safety net only blocks definitive + text-only confirmations. A freshly added third-party model that + LiteLLM doesn't know about must flow through to the provider.""" + assert ( + is_known_text_only_chat_model( + provider="CUSTOM", + custom_provider="brand_new_proxy", + model_name="brand-new-model-x9", + ) + is False + ) + + +def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch): + """Transient ``litellm.get_model_info`` exception ≠ block. The + helper swallows the error and treats it as 'unknown' → False.""" + import app.services.provider_capabilities as pc + + def _raise(**_kwargs): + raise RuntimeError("intentional test failure") + + monkeypatch.setattr(pc.litellm, "get_model_info", _raise) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_safety_net_fires_only_on_explicit_false(monkeypatch): + """Stub LiteLLM to assert the only path that returns True is the + explicit ``supports_vision=False`` case. Anything else (True, + None, missing key) returns False from the predicate.""" + import app.services.provider_capabilities as pc + + def _info_explicit_false(**_kwargs): + return {"supports_vision": False, "max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="text-only-stub", + ) + is True + ) + + def _info_true(**_kwargs): + return {"supports_vision": True} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_true) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="vision-stub", + ) + is False + ) + + def _info_missing(**_kwargs): + return {"max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="missing-key-stub", + ) + is False + ) diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py new file mode 100644 index 000000000..423b64ddb --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py @@ -0,0 +1,398 @@ +"""Unit tests for video-presentation Celery task billing integration. + +Mirrors ``test_podcast_billing.py`` for the video-presentation task. +Validates the same wrap-graph-in-billable_call pattern and ensures the +larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is +threaded through. 
+ +Coverage: + +* Free config: graph runs, ``billable_call`` invoked with the video + reserve override. +* Premium config: same wiring with ``billing_tier='premium'``. +* Quota denial: graph not invoked, row → FAILED, reason code surfaced. +* Resolver failure: row → FAILED with ``billing_resolution_failed``. +""" + +from __future__ import annotations + +import contextlib +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + def filter(self, *_args, **_kwargs): + return self + + +class _FakeSession: + def __init__(self, video): + self._video = video + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self._video) + + async def commit(self): + self.commit_count += 1 + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return None + + +class _FakeSessionMaker: + def __init__(self, session: _FakeSession): + self._session = session + + def __call__(self): + return self._session + + +def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace: + return SimpleNamespace( + id=video_id, + title="Test Presentation", + thread_id=thread_id, + status=None, + slides=None, + scene_codes=None, + ) + + +_CALL_LOG: list[dict[str, Any]] = [] + + +@contextlib.asynccontextmanager +async def _ok_billable_call(**kwargs): + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + + +@contextlib.asynccontextmanager +async def _denying_billable_call(**kwargs): + from app.services.billable_calls import QuotaInsufficientError + + _CALL_LOG.append(kwargs) + raise QuotaInsufficientError( + usage_type=kwargs.get("usage_type", "?"), + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield SimpleNamespace() # pragma: no cover + + +@contextlib.asynccontextmanager +async def _settlement_failing_billable_call(**kwargs): + from app.services.billable_calls import BillingSettlementError + + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + raise BillingSettlementError( + usage_type=kwargs.get("usage_type", "?"), + user_id=kwargs["user_id"], + cause=RuntimeError("finalize failed"), + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_call_log(): + _CALL_LOG.clear() + yield + _CALL_LOG.clear() + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch): + from app.config import config as app_config + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=11, thread_id=99) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + assert search_space_id == 777 + assert thread_id == 99 + return user_id, "free", "openrouter/some-free-model" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + 
_fake_resolver, + ) + monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=11, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result["status"] == "ready" + assert result["video_presentation_id"] == 11 + assert video.status == VideoPresentationStatus.READY + + assert len(_CALL_LOG) == 1 + call = _CALL_LOG[0] + assert call["user_id"] == user_id + assert call["search_space_id"] == 777 + assert call["billing_tier"] == "free" + assert call["base_model"] == "openrouter/some-free-model" + assert call["usage_type"] == "video_presentation_generation" + assert ( + call["quota_reserve_micros_override"] + == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS + ) + # Background artifact audit rows intentionally omit the TokenUsage.thread_id + # FK to avoid coupling Celery audit commits to an active chat transaction. + assert "thread_id" not in call + assert call["call_details"] == { + "video_presentation_id": 11, + "title": "Test Presentation", + "thread_id": 99, + } + assert callable(call["billable_session_factory"]) + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_premium_tier(monkeypatch): + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video() + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return user_id, "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + await video_presentation_tasks._generate_video_presentation( + video_presentation_id=11, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert _CALL_LOG[0]["billing_tier"] == "premium" + assert _CALL_LOG[0]["base_model"] == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=12) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr( + video_presentation_tasks, "billable_call", _denying_billable_call + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + 
"ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=12, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 12, + "reason": "premium_quota_exhausted", + } + assert video.status == VideoPresentationStatus.FAILED + assert graph_invoked == [] + + +@pytest.mark.asyncio +async def test_billing_settlement_failure_marks_video_failed(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=14) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr( + video_presentation_tasks, + "billable_call", + _settlement_failing_billable_call, + ) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=14, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 14, + "reason": "billing_settlement_failed", + } + assert video.status == VideoPresentationStatus.FAILED + + +@pytest.mark.asyncio +async def test_resolver_failure_marks_video_failed(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=13) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _failing_resolver(sess, search_space_id, *, thread_id=None): + raise ValueError("Search space 777 not found") + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _failing_resolver, + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=13, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 13, + "reason": "billing_resolution_failed", + } + assert video.status == VideoPresentationStatus.FAILED + assert graph_invoked == [] diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 034aa484c..208204ca9 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -1,9 +1,21 @@ +import inspect +import json +import logging +import re +from pathlib import Path + import pytest +import app.tasks.chat.stream_new_chat as stream_new_chat_module +from app.agents.new_chat.errors import BusyError +from 
app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel from app.tasks.chat.stream_new_chat import ( StreamResult, + _classify_stream_exception, _contract_enforcement_active, _evaluate_file_contract_outcome, + _extract_resolved_file_path, + _log_chat_stream_error, _tool_output_has_error, ) @@ -17,6 +29,39 @@ def test_tool_output_error_detection(): assert not _tool_output_has_error({"result": "Updated file /notes.md"}) +def test_extract_resolved_file_path_prefers_structured_path(): + assert ( + _extract_resolved_file_path( + tool_name="write_file", + tool_output={"status": "completed", "path": "/docs/note.md"}, + tool_input=None, + ) + == "/docs/note.md" + ) + + +def test_extract_resolved_file_path_falls_back_to_tool_input(): + assert ( + _extract_resolved_file_path( + tool_name="edit_file", + tool_output={"status": "completed", "result": "updated"}, + tool_input={"file_path": "/docs/edited.md"}, + ) + == "/docs/edited.md" + ) + + +def test_extract_resolved_file_path_does_not_parse_result_text(): + assert ( + _extract_resolved_file_path( + tool_name="write_file", + tool_output={"result": "Updated file /docs/from-text.md"}, + tool_input=None, + ) + is None + ) + + def test_file_write_contract_outcome_reasons(): result = StreamResult(intent_detected="file_write") passed, reason = _evaluate_file_contract_outcome(result) @@ -45,3 +90,507 @@ def test_contract_enforcement_local_only(): result.filesystem_mode = "cloud" assert not _contract_enforcement_active(result) + + +def _extract_chat_stream_payload(record_message: str) -> dict: + prefix = "[chat_stream_error] " + assert record_message.startswith(prefix) + return json.loads(record_message[len(prefix) :]) + + +def test_unified_chat_stream_error_log_schema(caplog): + with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): + _log_chat_stream_error( + flow="new", + error_kind="server_error", + error_code="SERVER_ERROR", + severity="warn", + is_expected=False, + request_id="req-123", + thread_id=101, + search_space_id=202, + user_id="user-1", + message="Error during chat: boom", + ) + + record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) + payload = _extract_chat_stream_payload(record.message) + + required_keys = { + "event", + "flow", + "error_kind", + "error_code", + "severity", + "is_expected", + "request_id", + "thread_id", + "search_space_id", + "user_id", + "message", + } + assert required_keys.issubset(payload.keys()) + assert payload["event"] == "chat_stream_error" + assert payload["flow"] == "new" + assert payload["error_code"] == "SERVER_ERROR" + + +def test_premium_quota_uses_unified_chat_stream_log_shape(caplog): + with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): + _log_chat_stream_error( + flow="resume", + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id="req-premium", + thread_id=303, + search_space_id=404, + user_id="user-2", + message="Buy more tokens to continue with this model, or switch to a free model", + extra={"auto_fallback": False}, + ) + + record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) + payload = _extract_chat_stream_payload(record.message) + assert payload["event"] == "chat_stream_error" + assert payload["error_kind"] == "premium_quota_exhausted" + assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED" + assert payload["flow"] == "resume" + assert payload["is_expected"] is True + assert payload["auto_fallback"] 
is False + + +def test_stream_error_emission_keeps_machine_error_codes(): + source = inspect.getsource(stream_new_chat_module) + format_error_calls = re.findall(r"format_error\(", source) + emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source)) + + # All stream paths should route through one shared terminal error emitter. + assert len(format_error_calls) == 1 + assert { + "PREMIUM_QUOTA_EXHAUSTED", + "SERVER_ERROR", + }.issubset(emitted_error_codes) + assert 'flow: Literal["new", "regenerate"] = "new"' in source + assert "_emit_stream_terminal_error" in source + assert "flow=flow" in source + assert 'flow="resume"' in source + + +def test_stream_exception_classifies_rate_limited(): + exc = Exception( + '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}' + ) + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "rate_limited" + assert code == "RATE_LIMITED" + assert severity == "warn" + assert is_expected is True + assert "temporarily rate-limited" in user_message + assert extra is None + + +def test_stream_exception_classifies_openrouter_429_payload(): + exc = Exception( + 'OpenrouterException - {"error":{"message":"Provider returned error","code":429,' + '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}' + ) + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "rate_limited" + assert code == "RATE_LIMITED" + assert severity == "warn" + assert is_expected is True + assert "temporarily rate-limited" in user_message + assert extra is None + + +@pytest.mark.asyncio +async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch): + """``_preflight_llm`` is best-effort. + + - On rate-limit shaped exceptions (provider 429) it MUST re-raise so the + caller can drive the cooldown/repin branch. + - On any other transient failure it MUST swallow the error so the normal + stream path continues without surfacing preflight noise to the user. + """ + from types import SimpleNamespace + + from app.tasks.chat.stream_new_chat import _preflight_llm + + class _RateLimitedError(Exception): + """Class-name carries 'RateLimit' so _is_provider_rate_limited triggers.""" + + rate_calls: list[dict] = [] + other_calls: list[dict] = [] + + async def _fake_acompletion_429(**kwargs): + rate_calls.append(kwargs) + raise _RateLimitedError("simulated 429") + + async def _fake_acompletion_other(**kwargs): + other_calls.append(kwargs) + raise RuntimeError("some unrelated transient failure") + + fake_llm = SimpleNamespace( + model="openrouter/google/gemma-4-31b-it:free", + api_key="test", + api_base=None, + ) + + import litellm # type: ignore[import-not-found] + + monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429) + with pytest.raises(_RateLimitedError): + await _preflight_llm(fake_llm) + assert len(rate_calls) == 1 + assert rate_calls[0]["max_tokens"] == 1 + assert rate_calls[0]["stream"] is False + + monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other) + # MUST NOT raise: non-rate-limit failures are swallowed. 
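+    # A minimal sketch of the helper shape these two branches assume. This is
+    # an illustration inferred from this test's assertions (the ``messages``
+    # payload and the helper internals are assumptions), not a copy of the
+    # real implementation in app.tasks.chat.stream_new_chat:
+    #
+    #     async def _preflight_llm(llm) -> None:
+    #         if llm.model == "auto":
+    #             return  # router owns rate-limit accounting (see next test)
+    #         try:
+    #             await litellm.acompletion(
+    #                 model=llm.model,
+    #                 api_key=llm.api_key,
+    #                 api_base=llm.api_base,
+    #                 messages=[{"role": "user", "content": "ping"}],
+    #                 max_tokens=1,
+    #                 stream=False,
+    #             )
+    #         except Exception as exc:
+    #             if _is_provider_rate_limited(exc):
+    #                 raise  # caller drives the cooldown/repin branch
+    #             # any other failure is swallowed: preflight is best-effort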
+ await _preflight_llm(fake_llm) + assert len(other_calls) == 1 + + +@pytest.mark.asyncio +async def test_preflight_skipped_for_auto_router_model(): + """Router-mode ``model='auto'`` has no single deployment to ping; the + LiteLLM router itself owns per-deployment rate-limit accounting, so the + preflight helper must short-circuit instead of issuing a probe.""" + from types import SimpleNamespace + + from app.tasks.chat.stream_new_chat import _preflight_llm + + fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None) + # Should return without raising or making any LiteLLM call. + await _preflight_llm(fake_llm) + + +@pytest.mark.asyncio +async def test_settle_speculative_agent_build_swallows_exceptions(): + """``_settle_speculative_agent_build`` MUST always return cleanly so the + caller can safely re-touch the request-scoped session afterwards. + + The helper guards the parallel preflight + agent-build path: when the + speculative build is being discarded (429 or non-429 preflight failure) + we await it solely to release any in-flight ``AsyncSession`` usage — + the build's outcome is irrelevant. Any exception (including + ``CancelledError``) leaking out would skip the caller's recovery flow + and re-introduce the very session-concurrency hazard the helper exists + to prevent. + """ + import asyncio + + from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build + + async def _raises() -> None: + raise RuntimeError("speculative build crashed") + + async def _succeeds() -> str: + return "agent" + + async def _slow() -> None: + await asyncio.sleep(0.05) + + for coro in (_raises(), _succeeds(), _slow()): + task = asyncio.create_task(coro) + await _settle_speculative_agent_build(task) + assert task.done() + + +@pytest.mark.asyncio +async def test_settle_speculative_agent_build_handles_already_done_task(): + """Done tasks (success or failure) must still be settled without raising.""" + import asyncio + + from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build + + async def _ok() -> str: + return "ok" + + async def _bad() -> None: + raise ValueError("nope") + + ok_task = asyncio.create_task(_ok()) + bad_task = asyncio.create_task(_bad()) + # Drive both to completion before settling. + await asyncio.sleep(0) + await asyncio.sleep(0) + + await _settle_speculative_agent_build(ok_task) + await _settle_speculative_agent_build(bad_task) + assert ok_task.result() == "ok" + # ``bad_task`` exception was consumed by the settle helper; calling + # ``.exception()`` after the fact must still return the original error + # (the helper observes it but doesn't clear it). 
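+    # For orientation, a minimal sketch of the settle helper that both tests
+    # constrain. The shape is assumed from the contracts asserted here; the
+    # real helper lives in app.tasks.chat.stream_new_chat and may differ:
+    #
+    #     async def _settle_speculative_agent_build(task: asyncio.Task) -> None:
+    #         try:
+    #             await task  # drains any in-flight AsyncSession usage
+    #         except BaseException:
+    #             pass  # outcome irrelevant; nothing (even CancelledError) escapes
+    #
+    # Awaiting a failed Task observes its exception without clearing it, so
+    # Task.exception() below still returns the original ValueError.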
+ assert isinstance(bad_task.exception(), ValueError) + + +def test_stream_exception_classifies_thread_busy(): + exc = BusyError(request_id="thread-123") + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "THREAD_BUSY" + assert severity == "warn" + assert is_expected is True + assert "still finishing for this thread" in user_message + assert extra is None + + +def test_stream_exception_classifies_thread_busy_from_message(): + exc = Exception("Thread is busy with another request") + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "THREAD_BUSY" + assert severity == "warn" + assert is_expected is True + assert "still finishing for this thread" in user_message + assert extra is None + + +def test_stream_exception_classifies_turn_cancelling_when_cancel_requested(): + thread_id = "thread-cancelling-1" + reset_cancel(thread_id) + request_cancel(thread_id) + exc = BusyError(request_id=thread_id) + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "TURN_CANCELLING" + assert severity == "info" + assert is_expected is True + assert "stopping" in user_message + assert isinstance(extra, dict) + assert "retry_after_ms" in extra + + +def test_premium_classification_is_error_code_driven(): + classifier_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/chat-error-classifier.ts" + ) + source = classifier_path.read_text(encoding="utf-8") + + assert "PREMIUM_KEYWORDS" not in source + assert "RATE_LIMIT_KEYWORDS" not in source + assert "normalized.includes(" not in source + assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source + + +def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook(): + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + source = page_path.read_text(encoding="utf-8") + + assert "onPreAcceptFailure?: () => Promise;" in source + assert "if (!accepted) {" in source + assert "await onPreAcceptFailure?.();" in source + assert "await onAcceptedStreamError?.();" in source + assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source + assert "setMessageDocumentsMap((prev) => {" in source + + +def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): + user_message_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/components/assistant-ui/user-message.tsx" + ) + source = user_message_path.read_text(encoding="utf-8") + + assert "Not sent. Edit and retry." 
not in source + assert "failed_pre_accept" not in source + + +def test_network_send_failures_use_unified_retry_toast_message(): + classifier_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/chat-error-classifier.ts" + ) + classifier_source = classifier_path.read_text(encoding="utf-8") + request_errors_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/chat-request-errors.ts" + ) + request_errors_source = request_errors_path.read_text(encoding="utf-8") + + assert '"send_failed_pre_accept"' in classifier_source + assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source + assert 'errorCode === "TURN_CANCELLING"' in classifier_source + assert "if (withCode.code) return withCode.code;" in classifier_source + assert 'userMessage: "Message not sent. Please retry."' in classifier_source + assert 'userMessage: "Connection issue. Please try again."' in classifier_source + assert "const passthroughCodes = new Set([" in request_errors_source + assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source + assert '"THREAD_BUSY"' in request_errors_source + assert '"TURN_CANCELLING"' in request_errors_source + assert '"AUTH_EXPIRED"' in request_errors_source + assert '"UNAUTHORIZED"' in request_errors_source + assert '"RATE_LIMITED"' in request_errors_source + assert '"NETWORK_ERROR"' in request_errors_source + assert '"STREAM_PARSE_ERROR"' in request_errors_source + assert '"TOOL_EXECUTION_ERROR"' in request_errors_source + assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source + assert '"SERVER_ERROR"' in request_errors_source + assert "passthroughCodes.has(existingCode)" in request_errors_source + assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source + assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source + assert "Failed to start chat. Please try again." not in classifier_source + + +def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + source = page_path.read_text(encoding="utf-8") + + # Each flow tracks accepted boundary and passes it into shared terminal handling. + # The acceptance boundary is still meaningful post-refactor: it gates + # local-state cleanup (onPreAcceptFailure path) and lets the shared + # terminal handler distinguish pre-accept aborts from in-stream errors. + assert "let newAccepted = false;" in source + assert "let resumeAccepted = false;" in source + assert "let regenerateAccepted = false;" in source + assert "accepted: newAccepted," in source + assert "accepted: resumeAccepted," in source + assert "accepted: regenerateAccepted," in source + + # NOTE: The FE-side persistence guards previously asserted here + # ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;", + # "if (newAccepted && !userPersisted) {") have been intentionally + # removed by the SSE-based message-id handshake refactor. Persistence + # is now server-authoritative: persist_user_turn / persist_assistant_shell + # run inside stream_new_chat / stream_resume_chat unconditionally and + # the FE consumes data-user-message-id / data-assistant-message-id + # SSE events to learn the canonical primary keys. There is therefore + # no FE call-site to guard, and the shared terminal handler relies + # purely on the `accepted` field above (forwarded to onAbort / + # onAcceptedStreamError) to drive UI cleanup. 
See + # tests/integration/chat/test_message_id_sse.py for the new + # cross-tier ID coherence guarantees. + + # The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent + # of the persistence refactor and must still exist on every + # start-stream fetch. + assert "const fetchWithTurnCancellingRetry = useCallback(" in source + assert "computeFallbackTurnCancellingRetryDelay" in source + assert 'withMeta.errorCode === "TURN_CANCELLING"' in source + assert 'withMeta.errorCode === "THREAD_BUSY"' in source + assert "await fetchWithTurnCancellingRetry(() =>" in source + + +def test_cancel_active_turn_route_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source + assert "response_model=CancelActiveTurnResponse" in source + assert 'status="cancelling",' in source + assert 'error_code="TURN_CANCELLING",' in source + assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source + assert "retry_after_at=" in source + assert 'status="idle",' in source + assert 'error_code="NO_ACTIVE_TURN",' in source + + +def test_turn_status_route_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source + assert "response_model=TurnStatusResponse" in source + assert "_build_turn_status_payload(thread_id)" in source + assert "Permission.CHATS_READ.value" in source + assert "_raise_if_thread_busy_for_start(" in source + + +def test_turn_cancelling_retry_policy_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source + assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source + assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source + assert "def _compute_turn_cancelling_retry_delay(" in source + assert "retry-after-ms" in source + assert '"Retry-After"' in source + assert '"errorCode": "TURN_CANCELLING"' in source + + +def test_turn_status_sse_contract_exists(): + stream_source = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/tasks/chat/stream_new_chat.py" + ).read_text(encoding="utf-8") + state_source = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/streaming-state.ts" + ).read_text(encoding="utf-8") + pipeline_source = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/stream-pipeline.ts" + ).read_text(encoding="utf-8") + + assert '"turn-status"' in stream_source + assert '"status": "busy"' in stream_source + assert '"status": "idle"' in stream_source + assert 'type: "data-turn-status"' in state_source + assert 'case "data-turn-status":' in pipeline_source + assert "end_turn(str(chat_id))" in stream_source + + +def test_chat_deepagent_forwards_resolved_model_name_to_both_builders(): + """Regression guard: both system-prompt builders in chat_deepagent.py + must receive ``model_name=_resolve_prompt_model_name(...)`` so the + provider-variant dispatch can render the right ```` + block. Without this the prompt silently falls back to the empty + ``"default"`` variant — the original bug being fixed. 
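+
+    The call-site shape being enforced, as an illustration only (the argument
+    passed to ``_resolve_prompt_model_name`` is assumed here, not read from
+    chat_deepagent.py):
+
+        build_surfsense_system_prompt(
+            ...,
+            model_name=_resolve_prompt_model_name(resolved_model),
+        )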
+ + This test mirrors :func:`test_stream_error_emission_keeps_machine_error_codes` + in style: it inspects module source text + a regex to enforce the + call-site shape, not just the wrapper layer (the wrappers already + forward ``model_name`` correctly, so testing them would not catch + the actual missed plumbing). + """ + import app.agents.new_chat.chat_deepagent as chat_deepagent_module + + source = inspect.getsource(chat_deepagent_module) + + # Helper itself must be defined. + assert "def _resolve_prompt_model_name(" in source + + # Both builder calls must forward the resolved model name. Match + # across newlines + whitespace because the kwargs are split over + # multiple lines. + pattern = re.compile( + r"build_(?:surfsense|configurable)_system_prompt\([^)]*" + r"model_name=_resolve_prompt_model_name\(", + re.DOTALL, + ) + matches = pattern.findall(source) + assert len(matches) == 2, ( + "Expected both system-prompt builder call sites to forward " + "`model_name=_resolve_prompt_model_name(...)`, found " + f"{len(matches)}" + ) diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index 209c42a9c..3e371cecc 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -62,7 +62,7 @@ wheels = [ [[package]] name = "aiohttp" -version = "3.13.5" +version = "3.13.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohappyeyeballs" }, @@ -73,76 +73,76 @@ dependencies = [ { name = "propcache" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/77/9a/152096d4808df8e4268befa55fba462f440f14beab85e8ad9bf990516918/aiohttp-3.13.5.tar.gz", hash = "sha256:9d98cc980ecc96be6eb4c1994ce35d28d8b1f5e5208a23b421187d1209dbb7d1", size = 7858271 } +sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748 } wheels = [ - { url = "https://files.pythonhosted.org/packages/be/6f/353954c29e7dcce7cf00280a02c75f30e133c00793c7a2ed3776d7b2f426/aiohttp-3.13.5-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:023ecba036ddd840b0b19bf195bfae970083fd7024ce1ac22e9bba90464620e9", size = 748876 }, - { url = "https://files.pythonhosted.org/packages/f5/1b/428a7c64687b3b2e9cd293186695affc0e1e54a445d0361743b231f11066/aiohttp-3.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:15c933ad7920b7d9a20de151efcd05a6e38302cbf0e10c9b2acb9a42210a2416", size = 499557 }, - { url = "https://files.pythonhosted.org/packages/29/47/7be41556bfbb6917069d6a6634bb7dd5e163ba445b783a90d40f5ac7e3a7/aiohttp-3.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ab2899f9fa2f9f741896ebb6fa07c4c883bfa5c7f2ddd8cf2aafa86fa981b2d2", size = 500258 }, - { url = "https://files.pythonhosted.org/packages/67/84/c9ecc5828cb0b3695856c07c0a6817a99d51e2473400f705275a2b3d9239/aiohttp-3.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a60eaa2d440cd4707696b52e40ed3e2b0f73f65be07fd0ef23b6b539c9c0b0b4", size = 1749199 }, - { url = "https://files.pythonhosted.org/packages/f0/d3/3c6d610e66b495657622edb6ae7c7fd31b2e9086b4ec50b47897ad6042a9/aiohttp-3.13.5-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:55b3bdd3292283295774ab585160c4004f4f2f203946997f49aac032c84649e9", size = 1721013 }, - { url = 
"https://files.pythonhosted.org/packages/49/a0/24409c12217456df0bae7babe3b014e460b0b38a8e60753d6cb339f6556d/aiohttp-3.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2b2355dc094e5f7d45a7bb262fe7207aa0460b37a0d87027dcf21b5d890e7d5", size = 1781501 }, - { url = "https://files.pythonhosted.org/packages/98/9d/b65ec649adc5bccc008b0957a9a9c691070aeac4e41cea18559fef49958b/aiohttp-3.13.5-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b38765950832f7d728297689ad78f5f2cf79ff82487131c4d26fe6ceecdc5f8e", size = 1878981 }, - { url = "https://files.pythonhosted.org/packages/57/d8/8d44036d7eb7b6a8ec4c5494ea0c8c8b94fbc0ed3991c1a7adf230df03bf/aiohttp-3.13.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b18f31b80d5a33661e08c89e202edabf1986e9b49c42b4504371daeaa11b47c1", size = 1767934 }, - { url = "https://files.pythonhosted.org/packages/31/04/d3f8211f273356f158e3464e9e45484d3fb8c4ce5eb2f6fe9405c3273983/aiohttp-3.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:33add2463dde55c4f2d9635c6ab33ce154e5ecf322bd26d09af95c5f81cfa286", size = 1566671 }, - { url = "https://files.pythonhosted.org/packages/41/db/073e4ebe00b78e2dfcacff734291651729a62953b48933d765dc513bf798/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:327cc432fdf1356fb4fbc6fe833ad4e9f6aacb71a8acaa5f1855e4b25910e4a9", size = 1705219 }, - { url = "https://files.pythonhosted.org/packages/48/45/7dfba71a2f9fd97b15c95c06819de7eb38113d2cdb6319669195a7d64270/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7c35b0bf0b48a70b4cb4fc5d7bed9b932532728e124874355de1a0af8ec4bc88", size = 1743049 }, - { url = "https://files.pythonhosted.org/packages/18/71/901db0061e0f717d226386a7f471bb59b19566f2cae5f0d93874b017271f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:df23d57718f24badef8656c49743e11a89fd6f5358fa8a7b96e728fda2abf7d3", size = 1749557 }, - { url = "https://files.pythonhosted.org/packages/08/d5/41eebd16066e59cd43728fe74bce953d7402f2b4ddfdfef2c0e9f17ca274/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:02e048037a6501a5ec1f6fc9736135aec6eb8a004ce48838cb951c515f32c80b", size = 1558931 }, - { url = "https://files.pythonhosted.org/packages/30/e6/4a799798bf05740e66c3a1161079bda7a3dd8e22ca392481d7a7f9af82a6/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:31cebae8b26f8a615d2b546fee45d5ffb76852ae6450e2a03f42c9102260d6fe", size = 1774125 }, - { url = "https://files.pythonhosted.org/packages/84/63/7749337c90f92bc2cb18f9560d67aa6258c7060d1397d21529b8004fcf6f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:888e78eb5ca55a615d285c3c09a7a91b42e9dd6fc699b166ebd5dee87c9ccf14", size = 1732427 }, - { url = "https://files.pythonhosted.org/packages/98/de/cf2f44ff98d307e72fb97d5f5bbae3bfcb442f0ea9790c0bf5c5c2331404/aiohttp-3.13.5-cp312-cp312-win32.whl", hash = "sha256:8bd3ec6376e68a41f9f95f5ed170e2fcf22d4eb27a1f8cb361d0508f6e0557f3", size = 433534 }, - { url = "https://files.pythonhosted.org/packages/aa/ca/eadf6f9c8fa5e31d40993e3db153fb5ed0b11008ad5d9de98a95045bed84/aiohttp-3.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:110e448e02c729bcebb18c60b9214a87ba33bac4a9fa5e9a5f139938b56c6cb1", size = 460446 }, - { url = 
"https://files.pythonhosted.org/packages/78/e9/d76bf503005709e390122d34e15256b88f7008e246c4bdbe915cd4f1adce/aiohttp-3.13.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a5029cc80718bbd545123cd8fe5d15025eccaaaace5d0eeec6bd556ad6163d61", size = 742930 }, - { url = "https://files.pythonhosted.org/packages/57/00/4b7b70223deaebd9bb85984d01a764b0d7bd6526fcdc73cca83bcbe7243e/aiohttp-3.13.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4bb6bf5811620003614076bdc807ef3b5e38244f9d25ca5fe888eaccea2a9832", size = 496927 }, - { url = "https://files.pythonhosted.org/packages/9c/f5/0fb20fb49f8efdcdce6cd8127604ad2c503e754a8f139f5e02b01626523f/aiohttp-3.13.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a84792f8631bf5a94e52d9cc881c0b824ab42717165a5579c760b830d9392ac9", size = 497141 }, - { url = "https://files.pythonhosted.org/packages/3b/86/b7c870053e36a94e8951b803cb5b909bfbc9b90ca941527f5fcafbf6b0fa/aiohttp-3.13.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:57653eac22c6a4c13eb22ecf4d673d64a12f266e72785ab1c8b8e5940d0e8090", size = 1732476 }, - { url = "https://files.pythonhosted.org/packages/b5/e5/4e161f84f98d80c03a238671b4136e6530453d65262867d989bbe78244d0/aiohttp-3.13.5-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5e5f7debc7a57af53fdf5c5009f9391d9f4c12867049d509bf7bb164a6e295b", size = 1706507 }, - { url = "https://files.pythonhosted.org/packages/d4/56/ea11a9f01518bd5a2a2fcee869d248c4b8a0cfa0bb13401574fa31adf4d4/aiohttp-3.13.5-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c719f65bebcdf6716f10e9eff80d27567f7892d8988c06de12bbbd39307c6e3a", size = 1773465 }, - { url = "https://files.pythonhosted.org/packages/eb/40/333ca27fb74b0383f17c90570c748f7582501507307350a79d9f9f3c6eb1/aiohttp-3.13.5-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d97f93fdae594d886c5a866636397e2bcab146fd7a132fd6bb9ce182224452f8", size = 1873523 }, - { url = "https://files.pythonhosted.org/packages/f0/d2/e2f77eef1acb7111405433c707dc735e63f67a56e176e72e9e7a2cd3f493/aiohttp-3.13.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3df334e39d4c2f899a914f1dba283c1aadc311790733f705182998c6f7cae665", size = 1754113 }, - { url = "https://files.pythonhosted.org/packages/fb/56/3f653d7f53c89669301ec9e42c95233e2a0c0a6dd051269e6e678db4fdb0/aiohttp-3.13.5-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fe6970addfea9e5e081401bcbadf865d2b6da045472f58af08427e108d618540", size = 1562351 }, - { url = "https://files.pythonhosted.org/packages/ec/a6/9b3e91eb8ae791cce4ee736da02211c85c6f835f1bdfac0594a8a3b7018c/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7becdf835feff2f4f335d7477f121af787e3504b48b449ff737afb35869ba7bb", size = 1693205 }, - { url = "https://files.pythonhosted.org/packages/98/fc/bfb437a99a2fcebd6b6eaec609571954de2ed424f01c352f4b5504371dd3/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:676e5651705ad5d8a70aeb8eb6936c436d8ebbd56e63436cb7dd9bb36d2a9a46", size = 1730618 }, - { url = "https://files.pythonhosted.org/packages/e4/b6/c8534862126191a034f68153194c389addc285a0f1347d85096d349bbc15/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:9b16c653d38eb1a611cc898c41e76859ca27f119d25b53c12875fd0474ae31a8", size = 1745185 }, - { url = 
"https://files.pythonhosted.org/packages/0b/93/4ca8ee2ef5236e2707e0fd5fecb10ce214aee1ff4ab307af9c558bda3b37/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:999802d5fa0389f58decd24b537c54aa63c01c3219ce17d1214cbda3c2b22d2d", size = 1557311 }, - { url = "https://files.pythonhosted.org/packages/57/ae/76177b15f18c5f5d094f19901d284025db28eccc5ae374d1d254181d33f4/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:ec707059ee75732b1ba130ed5f9580fe10ff75180c812bc267ded039db5128c6", size = 1773147 }, - { url = "https://files.pythonhosted.org/packages/01/a4/62f05a0a98d88af59d93b7fcac564e5f18f513cb7471696ac286db970d6a/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2d6d44a5b48132053c2f6cd5c8cb14bc67e99a63594e336b0f2af81e94d5530c", size = 1730356 }, - { url = "https://files.pythonhosted.org/packages/e4/85/fc8601f59dfa8c9523808281f2da571f8b4699685f9809a228adcc90838d/aiohttp-3.13.5-cp313-cp313-win32.whl", hash = "sha256:329f292ed14d38a6c4c435e465f48bebb47479fd676a0411936cc371643225cc", size = 432637 }, - { url = "https://files.pythonhosted.org/packages/c0/1b/ac685a8882896acf0f6b31d689e3792199cfe7aba37969fa91da63a7fa27/aiohttp-3.13.5-cp313-cp313-win_amd64.whl", hash = "sha256:69f571de7500e0557801c0b51f4780482c0ec5fe2ac851af5a92cfce1af1cb83", size = 458896 }, - { url = "https://files.pythonhosted.org/packages/5d/ce/46572759afc859e867a5bc8ec3487315869013f59281ce61764f76d879de/aiohttp-3.13.5-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:eb4639f32fd4a9904ab8fb45bf3383ba71137f3d9d4ba25b3b3f3109977c5b8c", size = 745721 }, - { url = "https://files.pythonhosted.org/packages/13/fe/8a2efd7626dbe6049b2ef8ace18ffda8a4dfcbe1bcff3ac30c0c7575c20b/aiohttp-3.13.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:7e5dc4311bd5ac493886c63cbf76ab579dbe4641268e7c74e48e774c74b6f2be", size = 497663 }, - { url = "https://files.pythonhosted.org/packages/9b/91/cc8cc78a111826c54743d88651e1687008133c37e5ee615fee9b57990fac/aiohttp-3.13.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:756c3c304d394977519824449600adaf2be0ccee76d206ee339c5e76b70ded25", size = 499094 }, - { url = "https://files.pythonhosted.org/packages/0a/33/a8362cb15cf16a3af7e86ed11962d5cd7d59b449202dc576cdc731310bde/aiohttp-3.13.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecc26751323224cf8186efcf7fbcbc30f4e1d8c7970659daf25ad995e4032a56", size = 1726701 }, - { url = "https://files.pythonhosted.org/packages/45/0c/c091ac5c3a17114bd76cbf85d674650969ddf93387876cf67f754204bd77/aiohttp-3.13.5-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10a75acfcf794edf9d8db50e5a7ec5fc818b2a8d3f591ce93bc7b1210df016d2", size = 1683360 }, - { url = "https://files.pythonhosted.org/packages/23/73/bcee1c2b79bc275e964d1446c55c54441a461938e70267c86afaae6fba27/aiohttp-3.13.5-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0f7a18f258d124cd678c5fe072fe4432a4d5232b0657fca7c1847f599233c83a", size = 1773023 }, - { url = "https://files.pythonhosted.org/packages/c7/ef/720e639df03004fee2d869f771799d8c23046dec47d5b81e396c7cda583a/aiohttp-3.13.5-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:df6104c009713d3a89621096f3e3e88cc323fd269dbd7c20afe18535094320be", size = 1853795 }, - { url = 
"https://files.pythonhosted.org/packages/bd/c9/989f4034fb46841208de7aeeac2c6d8300745ab4f28c42f629ba77c2d916/aiohttp-3.13.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:241a94f7de7c0c3b616627aaad530fe2cb620084a8b144d3be7b6ecfe95bae3b", size = 1730405 }, - { url = "https://files.pythonhosted.org/packages/ce/75/ee1fd286ca7dc599d824b5651dad7b3be7ff8d9a7e7b3fe9820d9180f7db/aiohttp-3.13.5-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c974fb66180e58709b6fc402846f13791240d180b74de81d23913abe48e96d94", size = 1558082 }, - { url = "https://files.pythonhosted.org/packages/c3/20/1e9e6650dfc436340116b7aa89ff8cb2bbdf0abc11dfaceaad8f74273a10/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:6e27ea05d184afac78aabbac667450c75e54e35f62238d44463131bd3f96753d", size = 1692346 }, - { url = "https://files.pythonhosted.org/packages/d8/40/8ebc6658d48ea630ac7903912fe0dd4e262f0e16825aa4c833c56c9f1f56/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:a79a6d399cef33a11b6f004c67bb07741d91f2be01b8d712d52c75711b1e07c7", size = 1698891 }, - { url = "https://files.pythonhosted.org/packages/d8/78/ea0ae5ec8ba7a5c10bdd6e318f1ba5e76fcde17db8275188772afc7917a4/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:c632ce9c0b534fbe25b52c974515ed674937c5b99f549a92127c85f771a78772", size = 1742113 }, - { url = "https://files.pythonhosted.org/packages/8a/66/9d308ed71e3f2491be1acb8769d96c6f0c47d92099f3bc9119cada27b357/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:fceedde51fbd67ee2bcc8c0b33d0126cc8b51ef3bbde2f86662bd6d5a6f10ec5", size = 1553088 }, - { url = "https://files.pythonhosted.org/packages/da/a6/6cc25ed8dfc6e00c90f5c6d126a98e2cf28957ad06fa1036bd34b6f24a2c/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:f92995dfec9420bb69ae629abf422e516923ba79ba4403bc750d94fb4a6c68c1", size = 1757976 }, - { url = "https://files.pythonhosted.org/packages/c1/2b/cce5b0ffe0de99c83e5e36d8f828e4161e415660a9f3e58339d07cce3006/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20ae0ff08b1f2c8788d6fb85afcb798654ae6ba0b747575f8562de738078457b", size = 1712444 }, - { url = "https://files.pythonhosted.org/packages/6c/cf/9e1795b4160c58d29421eafd1a69c6ce351e2f7c8d3c6b7e4ca44aea1a5b/aiohttp-3.13.5-cp314-cp314-win32.whl", hash = "sha256:b20df693de16f42b2472a9c485e1c948ee55524786a0a34345511afdd22246f3", size = 438128 }, - { url = "https://files.pythonhosted.org/packages/22/4d/eaedff67fc805aeba4ba746aec891b4b24cebb1a7d078084b6300f79d063/aiohttp-3.13.5-cp314-cp314-win_amd64.whl", hash = "sha256:f85c6f327bf0b8c29da7d93b1cabb6363fb5e4e160a32fa241ed2dce21b73162", size = 464029 }, - { url = "https://files.pythonhosted.org/packages/79/11/c27d9332ee20d68dd164dc12a6ecdef2e2e35ecc97ed6cf0d2442844624b/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:1efb06900858bb618ff5cee184ae2de5828896c448403d51fb633f09e109be0a", size = 778758 }, - { url = "https://files.pythonhosted.org/packages/04/fb/377aead2e0a3ba5f09b7624f702a964bdf4f08b5b6728a9799830c80041e/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:fee86b7c4bd29bdaf0d53d14739b08a106fdda809ca5fe032a15f52fae5fe254", size = 512883 }, - { url = "https://files.pythonhosted.org/packages/bb/a6/aa109a33671f7a5d3bd78b46da9d852797c5e665bfda7d6b373f56bff2ec/aiohttp-3.13.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:20058e23909b9e65f9da62b396b77dfa95965cbe840f8def6e572538b1d32e36", 
size = 516668 }, - { url = "https://files.pythonhosted.org/packages/79/b3/ca078f9f2fa9563c36fb8ef89053ea2bb146d6f792c5104574d49d8acb63/aiohttp-3.13.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cf20a8d6868cb15a73cab329ffc07291ba8c22b1b88176026106ae39aa6df0f", size = 1883461 }, - { url = "https://files.pythonhosted.org/packages/b7/e3/a7ad633ca1ca497b852233a3cce6906a56c3225fb6d9217b5e5e60b7419d/aiohttp-3.13.5-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:330f5da04c987f1d5bdb8ae189137c77139f36bd1cb23779ca1a354a4b027800", size = 1747661 }, - { url = "https://files.pythonhosted.org/packages/33/b9/cd6fe579bed34a906d3d783fe60f2fa297ef55b27bb4538438ee49d4dc41/aiohttp-3.13.5-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6f1cbf0c7926d315c3c26c2da41fd2b5d2fe01ac0e157b78caefc51a782196cf", size = 1863800 }, - { url = "https://files.pythonhosted.org/packages/c0/3f/2c1e2f5144cefa889c8afd5cf431994c32f3b29da9961698ff4e3811b79a/aiohttp-3.13.5-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:53fc049ed6390d05423ba33103ded7281fe897cf97878f369a527070bd95795b", size = 1958382 }, - { url = "https://files.pythonhosted.org/packages/66/1d/f31ec3f1013723b3babe3609e7f119c2c2fb6ef33da90061a705ef3e1bc8/aiohttp-3.13.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:898703aa2667e3c5ca4c54ca36cd73f58b7a38ef87a5606414799ebce4d3fd3a", size = 1803724 }, - { url = "https://files.pythonhosted.org/packages/0e/b4/57712dfc6f1542f067daa81eb61da282fab3e6f1966fca25db06c4fc62d5/aiohttp-3.13.5-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0494a01ca9584eea1e5fbd6d748e61ecff218c51b576ee1999c23db7066417d8", size = 1640027 }, - { url = "https://files.pythonhosted.org/packages/25/3c/734c878fb43ec083d8e31bf029daae1beafeae582d1b35da234739e82ee7/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:6cf81fe010b8c17b09495cbd15c1d35afbc8fb405c0c9cf4738e5ae3af1d65be", size = 1806644 }, - { url = "https://files.pythonhosted.org/packages/20/a5/f671e5cbec1c21d044ff3078223f949748f3a7f86b14e34a365d74a5d21f/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:c564dd5f09ddc9d8f2c2d0a301cd30a79a2cc1b46dd1a73bef8f0038863d016b", size = 1791630 }, - { url = "https://files.pythonhosted.org/packages/0b/63/fb8d0ad63a0b8a99be97deac8c04dacf0785721c158bdf23d679a87aa99e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:2994be9f6e51046c4f864598fd9abeb4fba6e88f0b2152422c9666dcd4aea9c6", size = 1809403 }, - { url = "https://files.pythonhosted.org/packages/59/0c/bfed7f30662fcf12206481c2aac57dedee43fe1c49275e85b3a1e1742294/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:157826e2fa245d2ef46c83ea8a5faf77ca19355d278d425c29fda0beb3318037", size = 1634924 }, - { url = "https://files.pythonhosted.org/packages/17/d6/fd518d668a09fd5a3319ae5e984d4d80b9a4b3df4e21c52f02251ef5a32e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:a8aca50daa9493e9e13c0f566201a9006f080e7c50e5e90d0b06f53146a54500", size = 1836119 }, - { url = "https://files.pythonhosted.org/packages/78/b7/15fb7a9d52e112a25b621c67b69c167805cb1f2ab8f1708a5c490d1b52fe/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3b13560160d07e047a93f23aaa30718606493036253d5430887514715b67c9d9", size = 1772072 }, - { url = 
"https://files.pythonhosted.org/packages/7e/df/57ba7f0c4a553fc2bd8b6321df236870ec6fd64a2a473a8a13d4f733214e/aiohttp-3.13.5-cp314-cp314t-win32.whl", hash = "sha256:9a0f4474b6ea6818b41f82172d799e4b3d29e22c2c520ce4357856fced9af2f8", size = 471819 }, - { url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441 }, + { url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158 }, + { url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037 }, + { url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556 }, + { url = "https://files.pythonhosted.org/packages/d6/10/88ff67cd48a6ec36335b63a640abe86135791544863e0cfe1f065d6cef7a/aiohttp-3.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d738ebab9f71ee652d9dbd0211057690022201b11197f9a7324fd4dba128aa97", size = 1757314 }, + { url = "https://files.pythonhosted.org/packages/8b/15/fdb90a5cf5a1f52845c276e76298c75fbbcc0ac2b4a86551906d54529965/aiohttp-3.13.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0ce692c3468fa831af7dceed52edf51ac348cebfc8d3feb935927b63bd3e8576", size = 1731819 }, + { url = "https://files.pythonhosted.org/packages/ec/df/28146785a007f7820416be05d4f28cc207493efd1e8c6c1068e9bdc29198/aiohttp-3.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8e08abcfe752a454d2cb89ff0c08f2d1ecd057ae3e8cc6d84638de853530ebab", size = 1793279 }, + { url = "https://files.pythonhosted.org/packages/10/47/689c743abf62ea7a77774d5722f220e2c912a77d65d368b884d9779ef41b/aiohttp-3.13.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5977f701b3fff36367a11087f30ea73c212e686d41cd363c50c022d48b011d8d", size = 1891082 }, + { url = "https://files.pythonhosted.org/packages/b0/b6/f7f4f318c7e58c23b761c9b13b9a3c9b394e0f9d5d76fbc6622fa98509f6/aiohttp-3.13.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:54203e10405c06f8b6020bd1e076ae0fe6c194adcee12a5a78af3ffa3c57025e", size = 1773938 }, + { url = "https://files.pythonhosted.org/packages/aa/06/f207cb3121852c989586a6fc16ff854c4fcc8651b86c5d3bd1fc83057650/aiohttp-3.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:358a6af0145bc4dda037f13167bef3cce54b132087acc4c295c739d05d16b1c3", size = 1579548 }, + { url = "https://files.pythonhosted.org/packages/6c/58/e1289661a32161e24c1fe479711d783067210d266842523752869cc1d9c2/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:898ea1850656d7d61832ef06aa9846ab3ddb1621b74f46de78fbc5e1a586ba83", size = 1714669 }, + { url = 
"https://files.pythonhosted.org/packages/96/0a/3e86d039438a74a86e6a948a9119b22540bae037d6ba317a042ae3c22711/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7bc30cceb710cf6a44e9617e43eebb6e3e43ad855a34da7b4b6a73537d8a6763", size = 1754175 }, + { url = "https://files.pythonhosted.org/packages/f4/30/e717fc5df83133ba467a560b6d8ef20197037b4bb5d7075b90037de1018e/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4a31c0c587a8a038f19a4c7e60654a6c899c9de9174593a13e7cc6e15ff271f9", size = 1762049 }, + { url = "https://files.pythonhosted.org/packages/e4/28/8f7a2d4492e336e40005151bdd94baf344880a4707573378579f833a64c1/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:2062f675f3fe6e06d6113eb74a157fb9df58953ffed0cdb4182554b116545758", size = 1570861 }, + { url = "https://files.pythonhosted.org/packages/78/45/12e1a3d0645968b1c38de4b23fdf270b8637735ea057d4f84482ff918ad9/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d1ba8afb847ff80626d5e408c1fdc99f942acc877d0702fe137015903a220a9", size = 1790003 }, + { url = "https://files.pythonhosted.org/packages/eb/0f/60374e18d590de16dcb39d6ff62f39c096c1b958e6f37727b5870026ea30/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b08149419994cdd4d5eecf7fd4bc5986b5a9380285bcd01ab4c0d6bfca47b79d", size = 1737289 }, + { url = "https://files.pythonhosted.org/packages/02/bf/535e58d886cfbc40a8b0013c974afad24ef7632d645bca0b678b70033a60/aiohttp-3.13.4-cp312-cp312-win32.whl", hash = "sha256:fc432f6a2c4f720180959bc19aa37259651c1a4ed8af8afc84dd41c60f15f791", size = 434185 }, + { url = "https://files.pythonhosted.org/packages/1e/1a/d92e3325134ebfff6f4069f270d3aac770d63320bd1fcd0eca023e74d9a8/aiohttp-3.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:6148c9ae97a3e8bff9a1fc9c757fa164116f86c100468339730e717590a3fb77", size = 461285 }, + { url = "https://files.pythonhosted.org/packages/e3/ac/892f4162df9b115b4758d615f32ec63d00f3084c705ff5526630887b9b42/aiohttp-3.13.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:63dd5e5b1e43b8fb1e91b79b7ceba1feba588b317d1edff385084fcc7a0a4538", size = 745744 }, + { url = "https://files.pythonhosted.org/packages/97/a9/c5b87e4443a2f0ea88cb3000c93a8fdad1ee63bffc9ded8d8c8e0d66efc6/aiohttp-3.13.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:746ac3cc00b5baea424dacddea3ec2c2702f9590de27d837aa67004db1eebc6e", size = 498178 }, + { url = "https://files.pythonhosted.org/packages/94/42/07e1b543a61250783650df13da8ddcdc0d0a5538b2bd15cef6e042aefc61/aiohttp-3.13.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bda8f16ea99d6a6705e5946732e48487a448be874e54a4f73d514660ff7c05d3", size = 498331 }, + { url = "https://files.pythonhosted.org/packages/20/d6/492f46bf0328534124772d0cf58570acae5b286ea25006900650f69dae0e/aiohttp-3.13.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4b061e7b5f840391e3f64d0ddf672973e45c4cfff7a0feea425ea24e51530fc2", size = 1744414 }, + { url = "https://files.pythonhosted.org/packages/e2/4d/e02627b2683f68051246215d2d62b2d2f249ff7a285e7a858dc47d6b6a14/aiohttp-3.13.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b252e8d5cd66184b570d0d010de742736e8a4fab22c58299772b0c5a466d4b21", size = 1719226 }, + { url = "https://files.pythonhosted.org/packages/7b/6c/5d0a3394dd2b9f9aeba6e1b6065d0439e4b75d41f1fb09a3ec010b43552b/aiohttp-3.13.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = 
"sha256:20af8aad61d1803ff11152a26146d8d81c266aa8c5aa9b4504432abb965c36a0", size = 1782110 }, + { url = "https://files.pythonhosted.org/packages/0d/2d/c20791e3437700a7441a7edfb59731150322424f5aadf635602d1d326101/aiohttp-3.13.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:13a5cc924b59859ad2adb1478e31f410a7ed46e92a2a619d6d1dd1a63c1a855e", size = 1884809 }, + { url = "https://files.pythonhosted.org/packages/c8/94/d99dbfbd1924a87ef643833932eb2a3d9e5eee87656efea7d78058539eff/aiohttp-3.13.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:534913dfb0a644d537aebb4123e7d466d94e3be5549205e6a31f72368980a81a", size = 1764938 }, + { url = "https://files.pythonhosted.org/packages/49/61/3ce326a1538781deb89f6cf5e094e2029cd308ed1e21b2ba2278b08426f6/aiohttp-3.13.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:320e40192a2dcc1cf4b5576936e9652981ab596bf81eb309535db7e2f5b5672f", size = 1570697 }, + { url = "https://files.pythonhosted.org/packages/b6/77/4ab5a546857bb3028fbaf34d6eea180267bdab022ee8b1168b1fcde4bfdd/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9e587fcfce2bcf06526a43cb705bdee21ac089096f2e271d75de9c339db3100c", size = 1702258 }, + { url = "https://files.pythonhosted.org/packages/79/63/d8f29021e39bc5af8e5d5e9da1b07976fb9846487a784e11e4f4eeda4666/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:9eb9c2eea7278206b5c6c1441fdd9dc420c278ead3f3b2cc87f9b693698cc500", size = 1740287 }, + { url = "https://files.pythonhosted.org/packages/55/3a/cbc6b3b124859a11bc8055d3682c26999b393531ef926754a3445b99dfef/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:29be00c51972b04bf9d5c8f2d7f7314f48f96070ca40a873a53056e652e805f7", size = 1753011 }, + { url = "https://files.pythonhosted.org/packages/e0/30/836278675205d58c1368b21520eab9572457cf19afd23759216c04483048/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:90c06228a6c3a7c9f776fe4fc0b7ff647fffd3bed93779a6913c804ae00c1073", size = 1566359 }, + { url = "https://files.pythonhosted.org/packages/50/b4/8032cc9b82d17e4277704ba30509eaccb39329dc18d6a35f05e424439e32/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:a533ec132f05fd9a1d959e7f34184cd7d5e8511584848dab85faefbaac573069", size = 1785537 }, + { url = "https://files.pythonhosted.org/packages/17/7d/5873e98230bde59f493bf1f7c3e327486a4b5653fa401144704df5d00211/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1c946f10f413836f82ea4cfb90200d2a59578c549f00857e03111cf45ad01ca5", size = 1740752 }, + { url = "https://files.pythonhosted.org/packages/7b/f2/13e46e0df051494d7d3c68b7f72d071f48c384c12716fc294f75d5b1a064/aiohttp-3.13.4-cp313-cp313-win32.whl", hash = "sha256:48708e2706106da6967eff5908c78ca3943f005ed6bcb75da2a7e4da94ef8c70", size = 433187 }, + { url = "https://files.pythonhosted.org/packages/ea/c0/649856ee655a843c8f8664592cfccb73ac80ede6a8c8db33a25d810c12db/aiohttp-3.13.4-cp313-cp313-win_amd64.whl", hash = "sha256:74a2eb058da44fa3a877a49e2095b591d4913308bb424c418b77beb160c55ce3", size = 459778 }, + { url = "https://files.pythonhosted.org/packages/6d/29/6657cc37ae04cacc2dbf53fb730a06b6091cc4cbe745028e047c53e6d840/aiohttp-3.13.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:e0a2c961fc92abeff61d6444f2ce6ad35bb982db9fc8ff8a47455beacf454a57", size = 749363 }, + { url = 
"https://files.pythonhosted.org/packages/90/7f/30ccdf67ca3d24b610067dc63d64dcb91e5d88e27667811640644aa4a85d/aiohttp-3.13.4-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:153274535985a0ff2bff1fb6c104ed547cec898a09213d21b0f791a44b14d933", size = 499317 }, + { url = "https://files.pythonhosted.org/packages/93/13/e372dd4e68ad04ee25dafb050c7f98b0d91ea643f7352757e87231102555/aiohttp-3.13.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:351f3171e2458da3d731ce83f9e6b9619e325c45cbd534c7759750cabf453ad7", size = 500477 }, + { url = "https://files.pythonhosted.org/packages/e5/fe/ee6298e8e586096fb6f5eddd31393d8544f33ae0792c71ecbb4c2bef98ac/aiohttp-3.13.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f989ac8bc5595ff761a5ccd32bdb0768a117f36dd1504b1c2c074ed5d3f4df9c", size = 1737227 }, + { url = "https://files.pythonhosted.org/packages/b0/b9/a7a0463a09e1a3fe35100f74324f23644bfc3383ac5fd5effe0722a5f0b7/aiohttp-3.13.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d36fc1709110ec1e87a229b201dd3ddc32aa01e98e7868083a794609b081c349", size = 1694036 }, + { url = "https://files.pythonhosted.org/packages/57/7c/8972ae3fb7be00a91aee6b644b2a6a909aedb2c425269a3bfd90115e6f8f/aiohttp-3.13.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:42adaeea83cbdf069ab94f5103ce0787c21fb1a0153270da76b59d5578302329", size = 1786814 }, + { url = "https://files.pythonhosted.org/packages/93/01/c81e97e85c774decbaf0d577de7d848934e8166a3a14ad9f8aa5be329d28/aiohttp-3.13.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:92deb95469928cc41fd4b42a95d8012fa6df93f6b1c0a83af0ffbc4a5e218cde", size = 1866676 }, + { url = "https://files.pythonhosted.org/packages/5a/5f/5b46fe8694a639ddea2cd035bf5729e4677ea882cb251396637e2ef1590d/aiohttp-3.13.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0c0c7c07c4257ef3a1df355f840bc62d133bcdef5c1c5ba75add3c08553e2eed", size = 1740842 }, + { url = "https://files.pythonhosted.org/packages/20/a2/0d4b03d011cca6b6b0acba8433193c1e484efa8d705ea58295590fe24203/aiohttp-3.13.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f062c45de8a1098cb137a1898819796a2491aec4e637a06b03f149315dff4d8f", size = 1566508 }, + { url = "https://files.pythonhosted.org/packages/98/17/e689fd500da52488ec5f889effd6404dece6a59de301e380f3c64f167beb/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:76093107c531517001114f0ebdb4f46858ce818590363e3e99a4a2280334454a", size = 1700569 }, + { url = "https://files.pythonhosted.org/packages/d8/0d/66402894dbcf470ef7db99449e436105ea862c24f7ea4c95c683e635af35/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:6f6ec32162d293b82f8b63a16edc80769662fbd5ae6fbd4936d3206a2c2cc63b", size = 1707407 }, + { url = "https://files.pythonhosted.org/packages/2f/eb/af0ab1a3650092cbd8e14ef29e4ab0209e1460e1c299996c3f8288b3f1ff/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:5903e2db3d202a00ad9f0ec35a122c005e85d90c9836ab4cda628f01edf425e2", size = 1752214 }, + { url = "https://files.pythonhosted.org/packages/5a/bf/72326f8a98e4c666f292f03c385545963cc65e358835d2a7375037a97b57/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2d5bea57be7aca98dbbac8da046d99b5557c5cf4e28538c4c786313078aca09e", size = 1562162 }, + { url = 
"https://files.pythonhosted.org/packages/67/9f/13b72435f99151dd9a5469c96b3b5f86aa29b7e785ca7f35cf5e538f74c0/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:bcf0c9902085976edc0232b75006ef38f89686901249ce14226b6877f88464fb", size = 1768904 }, + { url = "https://files.pythonhosted.org/packages/18/bc/28d4970e7d5452ac7776cdb5431a1164a0d9cf8bd2fffd67b4fb463aa56d/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c3295f98bfeed2e867cab588f2a146a9db37a85e3ae9062abf46ba062bd29165", size = 1723378 }, + { url = "https://files.pythonhosted.org/packages/53/74/b32458ca1a7f34d65bdee7aef2036adbe0438123d3d53e2b083c453c24dd/aiohttp-3.13.4-cp314-cp314-win32.whl", hash = "sha256:a598a5c5767e1369d8f5b08695cab1d8160040f796c4416af76fd773d229b3c9", size = 438711 }, + { url = "https://files.pythonhosted.org/packages/40/b2/54b487316c2df3e03a8f3435e9636f8a81a42a69d942164830d193beb56a/aiohttp-3.13.4-cp314-cp314-win_amd64.whl", hash = "sha256:c555db4bc7a264bead5a7d63d92d41a1122fcd39cc62a4db815f45ad46f9c2c8", size = 464977 }, + { url = "https://files.pythonhosted.org/packages/47/fb/e41b63c6ce71b07a59243bb8f3b457ee0c3402a619acb9d2c0d21ef0e647/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:45abbbf09a129825d13c18c7d3182fecd46d9da3cfc383756145394013604ac1", size = 781549 }, + { url = "https://files.pythonhosted.org/packages/97/53/532b8d28df1e17e44c4d9a9368b78dcb6bf0b51037522136eced13afa9e8/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:74c80b2bc2c2adb7b3d1941b2b60701ee2af8296fc8aad8b8bc48bc25767266c", size = 514383 }, + { url = "https://files.pythonhosted.org/packages/1b/1f/62e5d400603e8468cd635812d99cb81cfdc08127a3dc474c647615f31339/aiohttp-3.13.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c97989ae40a9746650fa196894f317dafc12227c808c774929dda0ff873a5954", size = 518304 }, + { url = "https://files.pythonhosted.org/packages/90/57/2326b37b10896447e3c6e0cbef4fe2486d30913639a5cfd1332b5d870f82/aiohttp-3.13.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dae86be9811493f9990ef44fff1685f5c1a3192e9061a71a109d527944eed551", size = 1893433 }, + { url = "https://files.pythonhosted.org/packages/d2/b4/a24d82112c304afdb650167ef2fe190957d81cbddac7460bedd245f765aa/aiohttp-3.13.4-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:1db491abe852ca2fa6cc48a3341985b0174b3741838e1341b82ac82c8bd9e871", size = 1755901 }, + { url = "https://files.pythonhosted.org/packages/9e/2d/0883ef9d878d7846287f036c162a951968f22aabeef3ac97b0bea6f76d5d/aiohttp-3.13.4-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0e5d701c0aad02a7dce72eef6b93226cf3734330f1a31d69ebbf69f33b86666e", size = 1876093 }, + { url = "https://files.pythonhosted.org/packages/ad/52/9204bb59c014869b71971addad6778f005daa72a96eed652c496789d7468/aiohttp-3.13.4-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8ac32a189081ae0a10ba18993f10f338ec94341f0d5df8fff348043962f3c6f8", size = 1970815 }, + { url = "https://files.pythonhosted.org/packages/d6/b5/e4eb20275a866dde0f570f411b36c6b48f7b53edfe4f4071aa1b0728098a/aiohttp-3.13.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:98e968cdaba43e45c73c3f306fca418c8009a957733bac85937c9f9cf3f4de27", size = 1816223 }, + { url = 
"https://files.pythonhosted.org/packages/d8/23/e98075c5bb146aa61a1239ee1ac7714c85e814838d6cebbe37d3fe19214a/aiohttp-3.13.4-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca114790c9144c335d538852612d3e43ea0f075288f4849cf4b05d6cd2238ce7", size = 1649145 }, + { url = "https://files.pythonhosted.org/packages/d6/c1/7bad8be33bb06c2bb224b6468874346026092762cbec388c3bdb65a368ee/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ea2e071661ba9cfe11eabbc81ac5376eaeb3061f6e72ec4cc86d7cdd1ffbdbbb", size = 1816562 }, + { url = "https://files.pythonhosted.org/packages/5c/10/c00323348695e9a5e316825969c88463dcc24c7e9d443244b8a2c9cf2eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:34e89912b6c20e0fd80e07fa401fd218a410aa1ce9f1c2f1dad6db1bd0ce0927", size = 1800333 }, + { url = "https://files.pythonhosted.org/packages/84/43/9b2147a1df3559f49bd723e22905b46a46c068a53adb54abdca32c4de180/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:0e217cf9f6a42908c52b46e42c568bd57adc39c9286ced31aaace614b6087965", size = 1820617 }, + { url = "https://files.pythonhosted.org/packages/a9/7f/b3481a81e7a586d02e99387b18c6dafff41285f6efd3daa2124c01f87eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:0c296f1221e21ba979f5ac1964c3b78cfde15c5c5f855ffd2caab337e9cd9182", size = 1643417 }, + { url = "https://files.pythonhosted.org/packages/8f/72/07181226bc99ce1124e0f89280f5221a82d3ae6a6d9d1973ce429d48e52b/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:d99a9d168ebaffb74f36d011750e490085ac418f4db926cce3989c8fe6cb6b1b", size = 1849286 }, + { url = "https://files.pythonhosted.org/packages/1a/e6/1b3566e103eca6da5be4ae6713e112a053725c584e96574caf117568ffef/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cb19177205d93b881f3f89e6081593676043a6828f59c78c17a0fd6c1fbed2ba", size = 1782635 }, + { url = "https://files.pythonhosted.org/packages/37/58/1b11c71904b8d079eb0c39fe664180dd1e14bebe5608e235d8bfbadc8929/aiohttp-3.13.4-cp314-cp314t-win32.whl", hash = "sha256:c606aa5656dab6552e52ca368e43869c916338346bfaf6304e15c58fb113ea30", size = 472537 }, + { url = "https://files.pythonhosted.org/packages/bc/8f/87c56a1a1977d7dddea5b31e12189665a140fdb48a71e9038ff90bb564ec/aiohttp-3.13.4-cp314-cp314t-win_amd64.whl", hash = "sha256:014dcc10ec8ab8db681f0d68e939d1e9286a5aa2b993cbbdb0db130853e02144", size = 506381 }, ] [[package]] @@ -3723,7 +3723,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.83.4" +version = "1.83.14" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -3739,9 +3739,9 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/03/c4/30469c06ae7437a4406bc11e3c433cfd380a6771068cca15ea918dcd158f/litellm-1.83.4.tar.gz", hash = "sha256:6458d2030a41229460b321adee00517a91dbd8e63213cc953d355cb41d16f2d4", size = 17733899 } +sdist = { url = "https://files.pythonhosted.org/packages/8d/7c/c095649380adc96c8630273c1768c2ad1e74aa2ee1dd8dd05d218a60569f/litellm-1.83.14.tar.gz", hash = "sha256:24aef9b47cdc424c833e32f3727f411741c690832cd1fe4405e0077144fe09c9", size = 14836599 } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/bd/df19d3f8f6654535ee343a341fd921f81c411abf601a53e3eaef58129b02/litellm-1.83.4-py3-none-any.whl", hash = "sha256:17d7b4d48d47aca988ea4f762ddda5e7bd72cda3270192b22813d0330869d7b4", size = 16015555 }, + { url = 
"https://files.pythonhosted.org/packages/7f/5c/1b5691575420135e90578543b2bf219497caa33cfd0af64cb38f30288450/litellm-1.83.14-py3-none-any.whl", hash = "sha256:92b11ba2a32cf80707ddf388d18526696c7999a21b418c5e3b6eda1243d2cfdb", size = 16457054 }, ] [[package]] @@ -5124,7 +5124,7 @@ wheels = [ [[package]] name = "openai" -version = "2.30.0" +version = "2.24.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -5136,9 +5136,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/88/15/52580c8fbc16d0675d516e8749806eda679b16de1e4434ea06fb6feaa610/openai-2.30.0.tar.gz", hash = "sha256:92f7661c990bda4b22a941806c83eabe4896c3094465030dd882a71abe80c885", size = 676084 } +sdist = { url = "https://files.pythonhosted.org/packages/55/13/17e87641b89b74552ed408a92b231283786523edddc95f3545809fab673c/openai-2.24.0.tar.gz", hash = "sha256:1e5769f540dbd01cb33bc4716a23e67b9d695161a734aff9c5f925e2bf99a673", size = 658717 } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/9e/5bfa2270f902d5b92ab7d41ce0475b8630572e71e349b2a4996d14bdda93/openai-2.30.0-py3-none-any.whl", hash = "sha256:9a5ae616888eb2748ec5e0c5b955a51592e0b201a11f4262db920f2a78c5231d", size = 1146656 }, + { url = "https://files.pythonhosted.org/packages/c9/30/844dc675ee6902579b8eef01ed23917cc9319a1c9c0c14ec6e39340c96d0/openai-2.24.0-py3-none-any.whl", hash = "sha256:fed30480d7d6c884303287bde864980a4b137b60553ffbcf9ab4a233b7a73d94", size = 1120122 }, ] [[package]] @@ -6780,11 +6780,11 @@ wheels = [ [[package]] name = "python-dotenv" -version = "1.0.1" +version = "1.2.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 } +sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135 } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, + { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101 }, ] [[package]] @@ -7947,7 +7947,7 @@ wheels = [ [[package]] name = "surf-new-backend" -version = "0.0.19" +version = "0.0.21" source = { editable = "." 
} dependencies = [ { name = "alembic" }, @@ -8045,7 +8045,7 @@ requires-dist = [ { name = "composio", specifier = ">=0.10.9" }, { name = "datasets", specifier = ">=2.21.0" }, { name = "daytona", specifier = ">=0.146.0" }, - { name = "deepagents", specifier = ">=0.4.12" }, + { name = "deepagents", specifier = ">=0.4.12,<0.5" }, { name = "discord-py", specifier = ">=2.5.2" }, { name = "docling", specifier = ">=2.15.0" }, { name = "elasticsearch", specifier = ">=9.1.1" }, @@ -8070,7 +8070,7 @@ requires-dist = [ { name = "langgraph", specifier = ">=1.1.3" }, { name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" }, { name = "linkup-sdk", specifier = ">=0.2.4" }, - { name = "litellm", specifier = ">=1.83.4" }, + { name = "litellm", specifier = ">=1.83.7" }, { name = "llama-cloud-services", specifier = ">=0.6.25" }, { name = "markdown", specifier = ">=3.7" }, { name = "markdownify", specifier = ">=0.14.1" }, diff --git a/surfsense_browser_extension/package.json b/surfsense_browser_extension/package.json index 146dd177e..f127b85c0 100644 --- a/surfsense_browser_extension/package.json +++ b/surfsense_browser_extension/package.json @@ -1,7 +1,7 @@ { "name": "surfsense_browser_extension", "displayName": "Surfsense Browser Extension", - "version": "0.0.19", + "version": "0.0.21", "description": "Extension to collect Browsing History for SurfSense.", "author": "https://github.com/MODSetter", "engines": { diff --git a/surfsense_desktop/build/entitlements.mac.plist b/surfsense_desktop/build/entitlements.mac.plist new file mode 100644 index 000000000..5647e7759 --- /dev/null +++ b/surfsense_desktop/build/entitlements.mac.plist @@ -0,0 +1,35 @@ + <?xml version="1.0" encoding="UTF-8"?> + <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd"> + <plist version="1.0"> + <dict> + <key>com.apple.security.cs.allow-jit</key> + <true/> + <key>com.apple.security.cs.allow-unsigned-executable-memory</key> + <true/> + <key>com.apple.security.cs.allow-dyld-environment-variables</key> + <true/> + <key>com.apple.security.cs.disable-library-validation</key> + <true/> + <key>com.apple.security.network.client</key> + <true/> + <key>com.apple.security.network.server</key> + <true/> + <key>com.apple.security.device.camera</key> + <true/> + <key>com.apple.security.automation.apple-events</key> + <true/> + <key>com.apple.security.files.user-selected.read-write</key> + <true/> + </dict> + </plist> diff --git a/surfsense_desktop/electron-builder.yml b/surfsense_desktop/electron-builder.yml index b0014a57b..e4e7670ec 100644 --- a/surfsense_desktop/electron-builder.yml +++ b/surfsense_desktop/electron-builder.yml @@ -46,8 +46,11 @@ mac: icon: assets/icon.icns category: public.app-category.productivity artifactName: "${productName}-${version}-${arch}.${ext}" - hardenedRuntime: false + hardenedRuntime: true gatekeeperAssess: false + entitlements: build/entitlements.mac.plist + entitlementsInherit: build/entitlements.mac.plist + notarize: true extendInfo: NSAccessibilityUsageDescription: "SurfSense uses accessibility features to bring the app to the foreground and interact with the active application when you use desktop assists." NSScreenCaptureUsageDescription: "SurfSense uses screen capture so you can attach a selected region to chat (Screenshot Assist) or capture the full screen from the composer."
diff --git a/surfsense_desktop/package.json b/surfsense_desktop/package.json index e2712d8ea..4826b904e 100644 --- a/surfsense_desktop/package.json +++ b/surfsense_desktop/package.json @@ -1,6 +1,6 @@ { "name": "surfsense-desktop", - "version": "0.0.19", + "version": "0.0.21", "description": "SurfSense Desktop App", "main": "dist/main.js", "scripts": { diff --git a/surfsense_web/app/(home)/free/page.tsx b/surfsense_web/app/(home)/free/page.tsx index 8d9ed5cb1..3ddd5195f 100644 --- a/surfsense_web/app/(home)/free/page.tsx +++ b/surfsense_web/app/(home)/free/page.tsx @@ -127,7 +127,7 @@ const FAQ_ITEMS = [ { question: "What happens after I use my free tokens?", answer: - "After your free tokens, create a free SurfSense account to unlock 3 million more premium tokens. Additional tokens can be purchased at $1 per million. Non-premium models remain unlimited for registered users.", + "After your free tokens run out, create a free SurfSense account to unlock $5 of premium credit. Additional credit can be topped up at $1 for $1 of credit, billed at the actual provider cost. Non-premium models remain unlimited for registered users.", }, { question: "Is Claude AI available without login?", @@ -329,7 +329,7 @@ export default async function FreeHubPage() {

Want More Features?
- Create a free SurfSense account to unlock 3 million tokens, document uploads with + Create a free SurfSense account to unlock $5 of premium credit, document uploads with citations, team collaboration, and integrations with Slack, Google Drive, Notion, and 30+ more tools.
diff --git a/surfsense_web/app/(home)/pricing/page.tsx b/surfsense_web/app/(home)/pricing/page.tsx index 6ad9435bf..2a413b9a9 100644 --- a/surfsense_web/app/(home)/pricing/page.tsx +++ b/surfsense_web/app/(home)/pricing/page.tsx @@ -5,7 +5,7 @@ import { BreadcrumbNav } from "@/components/seo/breadcrumb-nav"; export const metadata: Metadata = { title: "Pricing | SurfSense - Free AI Search Plans", description: - "Explore SurfSense plans and pricing. Start free with 500 pages & 3M premium tokens. Use ChatGPT, Claude AI, and premium AI models. Pay-as-you-go tokens at $1 per million.", + "Explore SurfSense plans and pricing. Start free with 500 pages & $5 in premium credits. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost.", alternates: { canonical: "https://surfsense.com/pricing", }, diff --git a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx index 3017160e1..0c5662712 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx @@ -8,7 +8,7 @@ import { cn } from "@/lib/utils"; const TABS = [ { id: "pages", label: "Pages" }, - { id: "tokens", label: "Premium Tokens" }, + { id: "tokens", label: "Premium Credit" }, ] as const; type TabId = (typeof TABS)[number]["id"]; diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index bb8f62703..9b5510df3 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -13,6 +13,7 @@ import { useParams } from "next/navigation"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { z } from "zod"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms"; import { clearTargetCommentIdAtom, @@ -30,6 +31,7 @@ import { clearPlanOwnerRegistry, // extractWriteTodosFromContent, } from "@/atoms/chat/plan-state.atom"; +import { setPremiumAlertForThreadAtom } from "@/atoms/chat/premium-alert.atom"; import { closeReportPanelAtom } from "@/atoms/chat/report-panel.atom"; import { type AgentCreatedDocument, agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms"; import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; @@ -59,27 +61,35 @@ import { useMessagesSync } from "@/hooks/use-messages-sync"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { getBearerToken } from "@/lib/auth-utils"; +import { type ChatFlow, classifyChatError } from "@/lib/chat/chat-error-classifier"; +import { tagPreAcceptSendFailure, toHttpResponseError } from "@/lib/chat/chat-request-errors"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { isPodcastGenerating, looksLikePodcastRequest, setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; +import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; +import { + consumeSseEvents, + hasPersistableContent, + markInterruptsCompleted, + processSharedStreamEvent, +} from "@/lib/chat/stream-pipeline"; +import { + applyTurnIdToAssistantMessageList, + mergeChatTurnIdIntoMessage, + readStreamedChatTurnId, + 
readStreamedMessageId, +} from "@/lib/chat/stream-side-effects"; import { - addStepSeparator, addToolCall, - appendReasoning, - appendText, - appendToolInputDelta, buildContentForPersistence, buildContentForUI, type ContentPartsState, - endReasoning, - FrameBatchedUpdater, - readSSEStream, + type FrameBatchedUpdater, type ThinkingStepData, type ToolUIGate, - updateThinkingSteps, updateToolCall, } from "@/lib/chat/streaming-state"; import { @@ -99,8 +109,9 @@ import { import { NotFoundError } from "@/lib/error"; import { type BundleSubmit, HitlBundleProvider } from "@/lib/hitl"; import { + trackChatBlocked, trackChatCreated, - trackChatError, + trackChatErrorDetailed, trackChatMessageSent, trackChatResponseReceived, } from "@/lib/posthog/events"; @@ -128,25 +139,6 @@ const MobileReportPanel = dynamic( { ssr: false } ); -/** - * After a tool produces output, mark any previously-decided interrupt tool - * calls as completed so the ApprovalCard can transition from shimmer to done. - */ -function markInterruptsCompleted(contentParts: Array<{ type: string; result?: unknown }>): void { - for (const part of contentParts) { - if ( - part.type === "tool-call" && - typeof part.result === "object" && - part.result !== null && - (part.result as Record).__interrupt__ === true && - (part.result as Record).__decided__ && - !(part.result as Record).__completed__ - ) { - part.result = { ...(part.result as Record), __completed__: true }; - } - } -} - /** * Generate a synthetic ``toolCallId`` for an action_request that has no * matching streamed tool-call card (HITL-blocked subagent calls don't surface @@ -243,28 +235,20 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { * ``stream_new_chat.py``) keep the JSON from ballooning. */ const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; +const TURN_CANCELLING_INITIAL_DELAY_MS = 200; +const TURN_CANCELLING_BACKOFF_FACTOR = 2; +const TURN_CANCELLING_MAX_DELAY_MS = 1500; +const RECENT_CANCEL_WINDOW_MS = 5_000; -/** - * When a streamed message is persisted, the backend returns the durable - * ``turn_id`` (``configurable.turn_id`` from the agent run). Merge it - * into the assistant-ui message metadata so the per-turn "Revert turn" - * button can scope to this turn's actions even after a full chat reload. - */ -function mergeChatTurnIdIntoMessage( - msg: ThreadMessageLike, - turnId: string | null | undefined -): ThreadMessageLike { - if (!turnId) return msg; - const existingMeta = (msg.metadata ?? {}) as { custom?: Record }; - const existingCustom = existingMeta.custom ?? 
{}; - if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg; - return { - ...msg, - metadata: { - ...existingMeta, - custom: { ...existingCustom, chatTurnId: turnId }, - }, - }; +function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +function computeFallbackTurnCancellingRetryDelay(attempt: number): number { + const safeAttempt = Math.max(1, attempt); + const raw = + TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1); + return Math.min(raw, TURN_CANCELLING_MAX_DELAY_MS); } export default function NewChatPage() { @@ -277,6 +261,7 @@ export default function NewChatPage() { const [isRunning, setIsRunning] = useState(false); const [tokenUsageStore] = useState(() => createTokenUsageStore()); const abortControllerRef = useRef(null); + const recentCancelRequestedAtRef = useRef(0); const [pendingInterrupt, setPendingInterrupt] = useState<{ threadId: number; assistantMsgId: string; @@ -284,6 +269,63 @@ export default function NewChatPage() { bundleToolCallIds: string[]; } | null>(null); const toolsWithUI = TOOLS_WITH_UI_ALL; + const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); + + const persistAssistantErrorMessage = useCallback( + async ({ + threadId, + assistantMsgId, + text, + }: { + threadId: number | null; + assistantMsgId: string; + text: string; + }) => { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { + ...m, + content: [{ type: "text", text }], + } + : m + ) + ); + + if (!threadId) return; + + // Persist only temporary assistant placeholders to avoid duplicate rows + // when the message already has a database-backed ID. + if (!assistantMsgId.startsWith("msg-assistant-")) return; + + try { + const savedMessage = await appendMessage(threadId, { + role: "assistant", + content: [{ type: "text", text }], + }); + const newMsgId = `msg-${savedMessage.id}`; + tokenUsageStore.rename(assistantMsgId, newMsgId); + setMessages((prev) => + prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + ); + } catch (persistErr) { + console.error("Failed to persist assistant error message:", persistErr); + } + }, + [tokenUsageStore] + ); + + // NOTE: ``persistUserTurn`` / ``persistAssistantTurn`` callbacks + // were removed in the SSE-based message ID handshake refactor. + // ``stream_new_chat`` and ``stream_resume_chat`` now persist both + // the user and assistant rows server-side via + // ``persist_user_turn`` / ``persist_assistant_shell`` and emit + // ``data-user-message-id`` / ``data-assistant-message-id`` SSE + // events; the consumers below rename the optimistic ids in real + // time. ``persistAssistantErrorMessage`` (above) is intentionally + // kept — it is the pre-stream-error fallback fired when the + // server NEVER accepted the request, and the BE has nothing to + // persist in that case. // Get disabled tools from the tool toggle UI const disabledTools = useAtomValue(disabledToolsAtom); @@ -291,9 +333,10 @@ export default function NewChatPage() { // Get mentioned document IDs from the composer. 
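// A quick worked check of the fallback backoff defined above: with the
// constants in this diff (200ms initial delay, 2x factor, 1500ms cap), the
// four attempts wait 200ms, 400ms, 800ms, then 1500ms (1600ms capped). A
// server-supplied retryAfterMs, when present, takes precedence over this
// fallback entirely.
const fallbackDelays = [1, 2, 3, 4].map(computeFallbackTurnCancellingRetryDelay);
console.assert(JSON.stringify(fallbackDelays) === JSON.stringify([200, 400, 800, 1500]));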
const mentionedDocumentIds = useAtomValue(mentionedDocumentIdsAtom); const mentionedDocuments = useAtomValue(mentionedDocumentsAtom); + const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); const setMentionedDocuments = useSetAtom(mentionedDocumentsAtom); - const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); const setCurrentThreadState = useSetAtom(currentThreadAtom); + const setPremiumAlertForThread = useSetAtom(setPremiumAlertForThreadAtom); const setTargetCommentId = useSetAtom(setTargetCommentIdAtom); const clearTargetCommentId = useSetAtom(clearTargetCommentIdAtom); const closeReportPanel = useSetAtom(closeReportPanelAtom); @@ -317,6 +360,8 @@ export default function NewChatPage() { // Get current user for author info in shared chats const { data: currentUser } = useAtomValue(currentUserAtom); + const { data: agentFlags } = useAtomValue(agentFlagsAtom); + const localFilesystemEnabled = agentFlags?.enable_desktop_local_filesystem === true; // Live collaboration: sync session state and messages via Zero useChatSessionStateSync(threadId); @@ -408,6 +453,143 @@ export default function NewChatPage() { return Number.isNaN(parsed) ? 0 : parsed; }, [params.chat_id]); + const handleChatFailure = useCallback( + async ({ + error, + flow, + threadId, + assistantMsgId, + }: { + error: unknown; + flow: ChatFlow; + threadId: number | null; + assistantMsgId: string; + }) => { + const normalized = classifyChatError({ + error, + flow, + context: { + searchSpaceId, + threadId, + }, + }); + + const logger = + normalized.severity === "error" + ? console.error + : normalized.severity === "warn" + ? console.warn + : console.info; + logger(`[NewChatPage] ${flow} ${normalized.kind}:`, error); + + const telemetryPayload = { + flow, + kind: normalized.kind, + error_code: normalized.errorCode, + severity: normalized.severity, + is_expected: normalized.isExpected, + message: normalized.userMessage, + }; + if (normalized.telemetryEvent === "chat_blocked") { + trackChatBlocked(searchSpaceId, threadId, telemetryPayload); + } else { + trackChatErrorDetailed(searchSpaceId, threadId, telemetryPayload); + } + + if (normalized.channel === "silent") { + return; + } + + if (normalized.channel === "pinned_inline") { + if (threadId) { + setPremiumAlertForThread({ + threadId, + message: normalized.userMessage, + userId: currentUser?.id ?? null, + }); + } + if (normalized.assistantMessage) { + await persistAssistantErrorMessage({ + threadId, + assistantMsgId, + text: normalized.assistantMessage, + }); + } + return; + } + + toast.error(normalized.userMessage); + }, + [currentUser?.id, persistAssistantErrorMessage, searchSpaceId, setPremiumAlertForThread] + ); + + const handleStreamTerminalError = useCallback( + async ({ + error, + flow, + threadId, + assistantMsgId, + accepted, + onAbort, + onPreAcceptFailure, + onAcceptedStreamError, + }: { + error: unknown; + flow: ChatFlow; + threadId: number | null; + assistantMsgId: string; + accepted: boolean; + onAbort?: () => Promise; + onPreAcceptFailure?: () => Promise; + onAcceptedStreamError?: () => Promise; + }) => { + if (error instanceof Error && error.name === "AbortError") { + await onAbort?.(); + return; + } + + if (!accepted) { + await onPreAcceptFailure?.(); + } else { + await onAcceptedStreamError?.(); + } + + await handleChatFailure({ + error: !accepted ? tagPreAcceptSendFailure(error) : error, + flow, + threadId, + assistantMsgId: accepted ? 
assistantMsgId : "no-persist-assistant", + }); + }, + [handleChatFailure] + ); + + const fetchWithTurnCancellingRetry = useCallback(async (runFetch: () => Promise) => { + const maxAttempts = 4; + for (let attempt = 1; attempt <= maxAttempts; attempt += 1) { + const response = await runFetch(); + if (response.ok) { + return response; + } + const error = await toHttpResponseError(response); + const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number }; + const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING"; + const isRecentThreadBusyAfterCancel = + withMeta.errorCode === "THREAD_BUSY" && + Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS; + if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && attempt < maxAttempts) { + const waitMs = withMeta.retryAfterMs ?? computeFallbackTurnCancellingRetryDelay(attempt); + await sleep(waitMs); + continue; + } + throw error; + } + + throw Object.assign(new Error("Turn cancellation retry limit exceeded"), { + errorCode: "TURN_CANCELLING", + }); + }, []); + // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message const initializeThread = useCallback(async () => { @@ -577,12 +759,39 @@ export default function NewChatPage() { // Cancel ongoing request const cancelRun = useCallback(async () => { + if (threadId) { + const token = getBearerToken(); + if (token) { + const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + try { + const response = await fetch( + `${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`, + { + method: "POST", + headers: { + Authorization: `Bearer ${token}`, + }, + } + ); + if (response.ok) { + const payload = (await response.json()) as { + error_code?: string; + }; + if (payload.error_code === "TURN_CANCELLING") { + recentCancelRequestedAtRef.current = Date.now(); + } + } + } catch (error) { + console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error); + } + } + } if (abortControllerRef.current) { abortControllerRef.current.abort(); abortControllerRef.current = null; } setIsRunning(false); - }, []); + }, [threadId]); // Handle new message from user const onNew = useCallback( @@ -634,7 +843,12 @@ export default function NewChatPage() { ); } catch (error) { console.error("[NewChatPage] Failed to create thread:", error); - toast.error("Failed to start chat. Please try again."); + await handleChatFailure({ + error: tagPreAcceptSendFailure(error), + flow: "new", + threadId: currentThreadId, + assistantMsgId: "no-persist-assistant", + }); return; } } @@ -643,8 +857,13 @@ export default function NewChatPage() { setPendingUserImageUrls((prev) => prev.filter((u) => !urlsSnapshot.includes(u))); } - // Add user message to state - const userMsgId = `msg-user-${Date.now()}`; + // Add user message to state. Mutable because the SSE + // ``data-user-message-id`` handler (below) renames this + // optimistic id to the canonical ``msg-{db_id}`` once the + // backend's ``persist_user_turn`` resolves the row, and + // the in-stream flush / interrupt closures need to see + // the post-rename value via this live ``let`` binding. 
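// A hedged sketch of the id-handshake payloads the comment above refers to.
// This diff only shows the consumer side (readStreamedMessageId /
// readStreamedChatTurnId from @/lib/chat/stream-side-effects), so the exact
// wire field names are assumptions inferred from how the handlers below use
// the parsed result:
type StreamedMessageIdPayloadSketch = {
  message_id: number; // canonical new_chat_messages row id, becomes "msg-{db_id}"
  turn_id?: string | null; // durable configurable.turn_id used by "Revert turn"
};
function readStreamedMessageIdSketch(
  data: unknown
): { messageId: number; turnId: string | null } | null {
  if (typeof data !== "object" || data === null) return null;
  const d = data as Partial<StreamedMessageIdPayloadSketch>;
  if (typeof d.message_id !== "number") return null;
  return {
    messageId: d.message_id,
    turnId: typeof d.turn_id === "string" ? d.turn_id : null,
  };
}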
+ let userMsgId = `msg-user-${Date.now()}`; // Always include author metadata so the UI layer can decide visibility const authorMetadata = currentUser @@ -710,72 +929,33 @@ export default function NewChatPage() { })); } - const persistContent: unknown[] = [...userDisplayContent]; - - if (allMentionedDocs.length > 0) { - persistContent.push({ - type: "mentioned-documents", - documents: allMentionedDocs, - }); - } - - appendMessage(currentThreadId, { - role: "user", - content: persistContent, - }) - .then((savedMessage) => { - const newUserMsgId = `msg-${savedMessage.id}`; - setMessages((prev) => - prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) - ); - setMessageDocumentsMap((prev) => { - const docs = prev[userMsgId]; - if (!docs) return prev; - const { [userMsgId]: _, ...rest } = prev; - return { ...rest, [newUserMsgId]: docs }; - }); - if (isNewThread) { - queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); - } - }) - .catch((err) => console.error("Failed to persist user message:", err)); - // Start streaming response setIsRunning(true); const controller = new AbortController(); abortControllerRef.current = controller; - // Prepare assistant message - const assistantMsgId = `msg-assistant-${Date.now()}`; + // Prepare assistant message. Mutable for the same reason + // as ``userMsgId`` above — the ``data-assistant-message-id`` + // SSE handler reassigns this once + // ``persist_assistant_shell`` returns its canonical id. + let assistantMsgId = `msg-assistant-${Date.now()}`; const currentThinkingSteps = new Map(); - const batcher = new FrameBatchedUpdater(); - const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; - const { contentParts, toolCallIndices } = contentPartsState; + const { contentParts } = contentPartsState; let wasInterrupted = false; - let tokenUsageData: Record | null = null; - // Captured from ``data-turn-info`` at stream start. - let streamedChatTurnId: string | null = null; - - // Add placeholder assistant message - setMessages((prev) => [ - ...prev, - { - id: assistantMsgId, - role: "assistant", - content: [{ type: "text", text: "" }], - createdAt: new Date(), - }, - ]); + let newAccepted = false; + let streamBatcher: FrameBatchedUpdater | null = null; try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const selection = await getAgentFilesystemSelection(searchSpaceId); + const selection = await getAgentFilesystemSelection(searchSpaceId, { + localFilesystemEnabled, + }); if ( selection.filesystem_mode === "desktop_local_folder" && (!selection.local_filesystem_mounts || selection.local_filesystem_mounts.length === 0) @@ -807,33 +987,59 @@ export default function NewChatPage() { setMentionedDocuments([]); } - const response = await fetch(`${backendUrl}/api/v1/new_chat`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - chat_id: currentThreadId, - user_query: userQuery.trim(), - search_space_id: searchSpaceId, - filesystem_mode: selection.filesystem_mode, - client_platform: selection.client_platform, - local_filesystem_mounts: selection.local_filesystem_mounts, - messages: messageHistory, - mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined, - mentioned_surfsense_doc_ids: hasSurfsenseDocIds - ? 
mentionedDocumentIds.surfsense_doc_ids - : undefined, - disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, - ...(userImages.length > 0 ? { user_images: userImages } : {}), - }), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(`${backendUrl}/api/v1/new_chat`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + chat_id: currentThreadId, + user_query: userQuery.trim(), + search_space_id: searchSpaceId, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_mounts: selection.local_filesystem_mounts, + messages: messageHistory, + mentioned_document_ids: hasDocumentIds + ? mentionedDocumentIds.document_ids + : undefined, + mentioned_surfsense_doc_ids: hasSurfsenseDocIds + ? mentionedDocumentIds.surfsense_doc_ids + : undefined, + // Full mention metadata so the BE can embed a + // ``mentioned-documents`` ContentPart on the + // persisted user message (replaces the old FE-side + // injection in ``persistUserTurn``). + mentioned_documents: + allMentionedDocs.length > 0 + ? allMentionedDocs.map((d) => ({ + id: d.id, + title: d.title, + document_type: d.document_type, + })) + : undefined, + disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, + ...(userImages.length > 0 ? { user_images: userImages } : {}), + }), + signal: controller.signal, + }) + ); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); } + newAccepted = true; + setMessages((prev) => [ + ...prev, + { + id: assistantMsgId, + role: "assistant", + content: [{ type: "text", text: "" }], + createdAt: new Date(), + }, + ]); const flushMessages = () => { setMessages((prev) => @@ -844,123 +1050,41 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - // Force-flush helper: ``batcher.flush()`` is a no-op when - // ``dirty=false`` (e.g. a tool starts before any text - // streamed). ``scheduleFlush(); batcher.flush()`` sets - // the dirty bit FIRST so terminal events render - // promptly without the 50ms throttle delay. - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - // High-frequency event: deltas can fire dozens - // of times per call, so use throttled - // scheduleFlush (NOT forceFlush) to coalesce. - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? 
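// A plausible shape for createStreamFlushHelpers (the real implementation
// moved to @/lib/chat/stream-flush and is not shown in this diff). It bundles
// the scheduleFlush/forceFlush pair the removed inline code built by hand:
// the force variant sets the batcher's dirty bit FIRST so flush() is not a
// no-op when a terminal event (e.g. tool-input-start) arrives before any
// text delta has streamed.
function createStreamFlushHelpersSketch(flush: () => void) {
  const batcher = new FrameBatchedUpdater(); // assumed value import in the real module
  const scheduleFlush = () => batcher.schedule(flush); // throttled; coalesces high-frequency deltas
  const forceFlush = () => {
    scheduleFlush(); // mark dirty first...
    batcher.flush(); // ...then render immediately, skipping the ~50ms throttle
  };
  return { batcher, scheduleFlush, forceFlush };
}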
{}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - // addToolCall doesn't accept argsText today; - // backfill via updateToolCall so the new card - // renders pretty-printed JSON. - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": { - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { - const idx = toolCallIndices.get(parsed.toolCallId); - if (idx !== undefined) { - const part = contentParts[idx]; - if (part?.type === "tool-call" && part.toolName === "generate_podcast") { - setActivePodcastTaskId(String(parsed.output.podcast_id)); + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageStore.set(assistantMsgId, data); + }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } } } - } - forceFlush(); - break; - } - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - break; - } - + }, + }) + ) { + return; + } + switch (parsed.type) { case "data-thread-title-update": { const titleData = parsed.data as { threadId: number; title: string }; if (titleData?.title && titleData?.threadId === currentThreadId) { @@ -1006,13 +1130,21 @@ export default function NewChatPage() { name: string; args: Record; }>; - const paired = pairBundleToolCallIds(toolCallIndices, contentParts, actionRequests); + const paired = pairBundleToolCallIds( + contentPartsState.toolCallIndices, + contentPartsState.contentParts, + actionRequests + ); const bundleToolCallIds: string[] = []; for (let i = 0; i < actionRequests.length; i++) { const action = actionRequests[i]; let targetTcId = paired[i]; if (!targetTcId) { - targetTcId = freshSynthToolCallId(toolCallIndices, action.name, i); + targetTcId = freshSynthToolCallId( + contentPartsState.toolCallIndices, + action.name, + i + ); addToolCall( contentPartsState, toolsWithUI, @@ -1061,125 +1193,131 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + if (turnId) { setMessages((prev) => - prev.map((m) => - 
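// The delegation contract visible from both call sites of
// processSharedStreamEvent: shared events (text/reasoning deltas, step
// separators, the tool-call lifecycle, data-thinking-step, data-token-usage,
// data-turn-status) are consumed centrally and the helper returns true, so
// the caller returns early; anything flow-specific falls through to the
// switch. The signature below is a sketch; the real one lives in
// @/lib/chat/stream-pipeline and may differ.
declare function processSharedStreamEventSketch(
  event: { type: string } & Record<string, unknown>,
  ctx: {
    contentPartsState: ContentPartsState;
    toolsWithUI: ToolUIGate;
    currentThinkingSteps: Map<string, ThinkingStepData>;
    scheduleFlush: () => void;
    forceFlush: () => void;
    onTokenUsage?: (data: TokenUsageData) => void;
    onTurnStatus?: (data: { status: string }) => void;
    onToolOutputAvailable?: (
      event: { toolCallId: string; output?: { status?: string; podcast_id?: unknown } },
      sharedCtx: { contentPartsState: ContentPartsState; toolCallIndices: Map<string, number> }
    ) => void;
  }
): boolean; // true means "already handled, skip the flow-specific switch"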
m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; } - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); - break; - - case "error": - throw new Error(parsed.errorText || "Server error"); - } - } - - batcher.flush(); - - // Skip persistence for interrupted messages -- handleResume will persist the final version - const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); - if (contentParts.length > 0 && !wasInterrupted) { - try { - const savedMessage = await appendMessage(currentThreadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - turn_id: streamedChatTurnId, - }); - - // Update message ID from temporary to database ID so comments work immediately - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) - : m - ) - ); - - // Update pending interrupt with the new persisted message ID - setPendingInterrupt((prev) => - prev && prev.assistantMsgId === assistantMsgId - ? { ...prev, assistantMsgId: newMsgId } - : prev - ); - } catch (err) { - console.error("Failed to persist assistant message:", err); - } - - // Track successful response - trackChatResponseReceived(searchSpaceId, currentThreadId); - } - } catch (error) { - batcher.dispose(); - if (error instanceof Error && error.name === "AbortError") { - // Request was cancelled by user - persist partial response if any content was received - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && - (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); - if (hasContent && currentThreadId) { - const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); - try { - const savedMessage = await appendMessage(currentThreadId, { - role: "assistant", - content: partialContent, - turn_id: streamedChatTurnId, - }); - - // Update message ID from temporary to database ID - const newMsgId = `msg-${savedMessage.id}`; + case "data-user-message-id": { + // Server-authoritative user message id resolved by + // ``persist_user_turn`` (or recovered via ON CONFLICT). + // Rename the optimistic ``msg-user-XXX`` placeholder to + // the canonical ``msg-{db_id}`` so DB-id-gated UI + // (comments, edit-from-this-message) unlocks immediately, + // migrate the local mentioned-documents map, and reassign + // the closure variable so all downstream + // ``m.id === userMsgId`` checks see the new value. + const parsedMsg = readStreamedMessageId(parsed.data); + if (!parsedMsg) break; + const newUserMsgId = `msg-${parsedMsg.messageId}`; + const oldUserMsgId = userMsgId; setMessages((prev) => prev.map((m) => - m.id === assistantMsgId - ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) + m.id === oldUserMsgId + ? 
mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, parsedMsg.turnId) : m ) ); - } catch (err) { - console.error("Failed to persist partial assistant message:", err); + if (allMentionedDocs.length > 0) { + setMessageDocumentsMap((prev) => { + if (!(oldUserMsgId in prev)) { + return { ...prev, [newUserMsgId]: allMentionedDocs }; + } + const { [oldUserMsgId]: _removed, ...rest } = prev; + return { ...rest, [newUserMsgId]: allMentionedDocs }; + }); + } + userMsgId = newUserMsgId; + if (isNewThread) { + // First user-side row landed in ``new_chat_messages``; + // refresh the sidebar so the freshly-bumped + // ``thread.updated_at`` reorders this thread. + queryClient.invalidateQueries({ + queryKey: ["threads", String(searchSpaceId)], + }); + } + break; + } + + case "data-assistant-message-id": { + // Server-authoritative assistant message id resolved + // by ``persist_assistant_shell``. Rename the optimistic + // id, migrate ``tokenUsageStore`` so any pending + // ``data-token-usage`` payload binds to the new id, + // remap any in-flight ``pendingInterrupt`` reference, + // and reassign the closure variable so the in-stream + // flush callback (line ~1074) keeps writing to the + // renamed message. + const parsedMsg = readStreamedMessageId(parsed.data); + if (!parsedMsg) break; + const newAssistantMsgId = `msg-${parsedMsg.messageId}`; + const oldAssistantMsgId = assistantMsgId; + tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId); + setMessages((prev) => + prev.map((m) => + m.id === oldAssistantMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newAssistantMsgId }, parsedMsg.turnId) + : m + ) + ); + setPendingInterrupt((prev) => + prev && prev.assistantMsgId === oldAssistantMsgId + ? { ...prev, assistantMsgId: newAssistantMsgId } + : prev + ); + assistantMsgId = newAssistantMsgId; + break; } } - return; + }); + + batcher.flush(); + + // Server-authoritative persistence: ``stream_new_chat`` + // already wrote the user row in ``persist_user_turn`` + // (the FE renamed the optimistic id mid-stream via + // ``data-user-message-id``) and finalises the assistant + // row in ``finalize_assistant_turn`` from a shielded + // ``finally`` block. Nothing left for the FE to persist + // here — track the response and unblock the UI. + if (contentParts.length > 0 && !wasInterrupted) { + trackChatResponseReceived(searchSpaceId, currentThreadId); } - console.error("[NewChatPage] Chat error:", error); - - // Track chat error - trackChatError( - searchSpaceId, - currentThreadId, - error instanceof Error ? error.message : "Unknown error" - ); - - toast.error("Failed to get response. Please try again."); - // Update assistant message with error - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? { - ...m, - content: [ - { - type: "text", - text: "Sorry, there was an error. Please try again.", - }, - ], - } - : m - ) - ); + } catch (error) { + streamBatcher?.dispose(); + await handleStreamTerminalError({ + error, + flow: "new", + threadId: currentThreadId, + assistantMsgId, + accepted: newAccepted, + // Server-side ``finalize_assistant_turn`` runs from a + // shielded ``anyio.CancelScope(shield=True)`` finally + // block, so partial content (incl. abort-mid-stream) + // is already persisted by the BE for the assistant + // row, and ``persist_user_turn`` ran before any LLM + // call. The FE's only remaining responsibility on + // abort / accepted-stream-error is to surface the + // error toast (handled by ``handleStreamTerminalError`` + // itself). 
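// A hedged sketch of what toHttpResponseError (from
// @/lib/chat/chat-request-errors, body not shown in this diff) plausibly
// does, inferred from its consumers: fetchWithTurnCancellingRetry reads
// errorCode / retryAfterMs off the thrown Error, and the cancel endpoint
// returns { error_code } JSON. The retry_after_ms field name is an assumption.
async function toHttpResponseErrorSketch(
  response: Response
): Promise<Error & { errorCode?: string; retryAfterMs?: number }> {
  let body: { error_code?: string; retry_after_ms?: number; detail?: string } = {};
  try {
    body = await response.json();
  } catch {
    // Non-JSON error body: fall through with status-only metadata.
  }
  const err = new Error(body.detail ?? `Backend error: ${response.status}`) as Error & {
    errorCode?: string;
    retryAfterMs?: number;
  };
  err.errorCode = body.error_code;
  err.retryAfterMs = body.retry_after_ms;
  return err;
}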
+ onPreAcceptFailure: async () => { + // Pre-accept failure means the BE never accepted the + // request — no server-side persistence ran. Roll + // back the optimistic UI insertions we made before + // the fetch so the user message and any local + // mentioned-docs metadata don't linger. + setMessages((prev) => prev.filter((m) => m.id !== userMsgId)); + setMessageDocumentsMap((prev) => { + if (!(userMsgId in prev)) return prev; + const { [userMsgId]: _removed, ...rest } = prev; + return rest; + }); + }, + }); } finally { setIsRunning(false); abortControllerRef.current = null; @@ -1196,12 +1334,15 @@ export default function NewChatPage() { setAgentCreatedDocuments, queryClient, currentUser, + localFilesystemEnabled, disabledTools, updateChatTabTitle, tokenUsageStore, pendingUserImageUrls, setPendingUserImageUrls, - toolsWithUI, + fetchWithTurnCancellingRetry, + handleStreamTerminalError, + handleChatFailure, ] ); @@ -1214,7 +1355,12 @@ export default function NewChatPage() { }> ) => { if (!pendingInterrupt) return; - const { threadId: resumeThreadId, assistantMsgId } = pendingInterrupt; + const { threadId: resumeThreadId } = pendingInterrupt; + // Destructured separately as ``let`` so the SSE + // ``data-assistant-message-id`` handler (resume always + // allocates a fresh server-side row) can rename it to + // the canonical ``msg-{db_id}`` mid-stream. + let assistantMsgId = pendingInterrupt.assistantMsgId; setPendingInterrupt(null); setIsRunning(true); @@ -1229,7 +1375,6 @@ export default function NewChatPage() { abortControllerRef.current = controller; const currentThinkingSteps = new Map(); - const batcher = new FrameBatchedUpdater(); const contentPartsState: ContentPartsState = { contentParts: [], @@ -1238,9 +1383,8 @@ export default function NewChatPage() { toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; - let tokenUsageData: Record | null = null; - // Captured from ``data-turn-info`` at stream start. - let streamedChatTurnId: string | null = null; + let resumeAccepted = false; + let streamBatcher: FrameBatchedUpdater | null = null; const existingMsg = messages.find((m) => m.id === assistantMsgId); if (existingMsg && Array.isArray(existingMsg.content)) { @@ -1318,27 +1462,32 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const selection = await getAgentFilesystemSelection(searchSpaceId); - const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - search_space_id: searchSpaceId, - decisions, - disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, - filesystem_mode: selection.filesystem_mode, - client_platform: selection.client_platform, - local_filesystem_mounts: selection.local_filesystem_mounts, - }), - signal: controller.signal, + const selection = await getAgentFilesystemSelection(searchSpaceId, { + localFilesystemEnabled, }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + search_space_id: searchSpaceId, + decisions, + disabled_tools: disabledTools.length > 0 ? 
disabledTools : undefined, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_mounts: selection.local_filesystem_mounts, + }), + signal: controller.signal, + }) + ); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); } + resumeAccepted = true; const flushMessages = () => { setMessages((prev) => @@ -1349,115 +1498,51 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - forceFlush(); - break; - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageStore.set(assistantMsgId, data); + }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); } - } - break; - } - + }, + }) + ) { + return; + } + switch (parsed.type) { case "data-interrupt-request": { const interruptData = parsed.data as Record; const actionRequests = (interruptData.action_requests ?? 
[]) as Array<{ name: string; args: Record; }>; - const paired = pairBundleToolCallIds(toolCallIndices, contentParts, actionRequests); + const paired = pairBundleToolCallIds( + contentPartsState.toolCallIndices, + contentPartsState.contentParts, + actionRequests + ); const bundleToolCallIds: string[] = []; for (let i = 0; i < actionRequests.length; i++) { const action = actionRequests[i]; let targetTcId = paired[i]; if (!targetTcId) { - targetTcId = freshSynthToolCallId(toolCallIndices, action.name, i); + targetTcId = freshSynthToolCallId( + contentPartsState.toolCallIndices, + action.name, + i + ); addToolCall( contentPartsState, toolsWithUI, @@ -1504,64 +1589,74 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + if (turnId) { setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; } - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); + case "data-assistant-message-id": { + // Resume always allocates a fresh ``new_chat_messages`` + // row anchored to a new ``turn_id`` (the original + // interrupted turn's row stays as-is), so this is a + // real id swap. Rename the optimistic placeholder to + // ``msg-{db_id}`` and reassign closure state. Resume + // does NOT emit ``data-user-message-id`` — the user + // row belongs to the original interrupted turn. + const parsedMsg = readStreamedMessageId(parsed.data); + if (!parsedMsg) break; + const newAssistantMsgId = `msg-${parsedMsg.messageId}`; + const oldAssistantMsgId = assistantMsgId; + tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId); + setMessages((prev) => + prev.map((m) => + m.id === oldAssistantMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newAssistantMsgId }, parsedMsg.turnId) + : m + ) + ); + assistantMsgId = newAssistantMsgId; break; - - case "error": - throw new Error(parsed.errorText || "Server error"); + } } - } + }); batcher.flush(); - const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); - if (contentParts.length > 0) { - try { - const savedMessage = await appendMessage(resumeThreadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - turn_id: streamedChatTurnId, - }); - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) - : m - ) - ); - } catch (err) { - console.error("Failed to persist resumed assistant message:", err); - } - } + // Server-authoritative persistence: ``stream_resume_chat`` + // finalises the assistant row in + // ``finalize_assistant_turn`` from a shielded + // ``finally`` block (covers both happy-path and + // abort-mid-stream). FE has no remaining persistence + // work here. } catch (error) { - batcher.dispose(); - if (error instanceof Error && error.name === "AbortError") { - return; - } - console.error("[NewChatPage] Resume error:", error); - toast.error("Failed to resume. 
Please try again."); + streamBatcher?.dispose(); + await handleStreamTerminalError({ + error, + flow: "resume", + threadId: resumeThreadId, + assistantMsgId, + accepted: resumeAccepted, + }); } finally { setIsRunning(false); abortControllerRef.current = null; } }, - [pendingInterrupt, messages, searchSpaceId, tokenUsageStore, toolsWithUI] + [ + pendingInterrupt, + messages, + searchSpaceId, + localFilesystemEnabled, + disabledTools, + queryClient, + tokenUsageStore, + fetchWithTurnCancellingRetry, + handleStreamTerminalError, + ] ); useEffect(() => { @@ -1705,6 +1800,7 @@ export default function NewChatPage() { editExtras?: { userMessageContent: ThreadMessageLike["content"]; userImages: NewChatUserImagePayload[]; + sourceUserMessageId?: string; }, editFromPosition?: { /** Message id (numeric, parsed from ``msg-``) to rewind to. */ @@ -1736,11 +1832,13 @@ export default function NewChatPage() { let userQueryToDisplay: string | undefined; let originalUserMessageContent: ThreadMessageLike["content"] | null = null; let originalUserMessageMetadata: ThreadMessageLike["metadata"] | undefined; + let sourceUserMessageId: string | undefined = editExtras?.sourceUserMessageId; if (!isEdit) { // Reload mode - find and preserve the last user message content const lastUserMessage = [...messages].reverse().find((m) => m.role === "user"); if (lastUserMessage) { + sourceUserMessageId = lastUserMessage.id; originalUserMessageContent = lastUserMessage.content; originalUserMessageMetadata = lastUserMessage.metadata; // Extract text for the API request @@ -1755,34 +1853,17 @@ export default function NewChatPage() { userQueryToDisplay = newUserQuery; } - // Remove downstream messages from the UI immediately. The - // backend will also delete them from the database. - // - // When an explicit ``fromMessageId`` is passed, slice from - // that message forward; otherwise fall back to the legacy - // "drop the last 2" behaviour. - setMessages((prev) => { - if (editFromPosition?.fromMessageId != null) { - const targetId = `msg-${editFromPosition.fromMessageId}`; - const sliceIndex = prev.findIndex((m) => m.id === targetId); - if (sliceIndex >= 0) { - return prev.slice(0, sliceIndex); - } - } - if (prev.length >= 2) { - return prev.slice(0, -2); - } - return prev; - }); - // Start streaming setIsRunning(true); const controller = new AbortController(); abortControllerRef.current = controller; - // Add placeholder user message if we have a new query (edit mode) - const userMsgId = `msg-user-${Date.now()}`; - const assistantMsgId = `msg-assistant-${Date.now()}`; + // Add placeholder user message if we have a new query (edit mode). + // Mutable for the same reason as in ``onNew`` — both ids are + // renamed mid-stream by the new ``data-user-message-id`` / + // ``data-assistant-message-id`` SSE handlers below. + let userMsgId = `msg-user-${Date.now()}`; + let assistantMsgId = `msg-assistant-${Date.now()}`; const currentThinkingSteps = new Map(); const contentPartsState: ContentPartsState = { @@ -1791,13 +1872,9 @@ export default function NewChatPage() { currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; - const { contentParts, toolCallIndices } = contentPartsState; - const batcher = new FrameBatchedUpdater(); - let tokenUsageData: Record | null = null; - // Captured from ``data-turn-info`` at stream start; stamped - // onto persisted messages so future edits can locate the - // right LangGraph checkpoint. 
- let streamedChatTurnId: string | null = null; + const { contentParts } = contentPartsState; + let regenerateAccepted = false; + let streamBatcher: FrameBatchedUpdater | null = null; // Add placeholder messages to UI // Always add back the user message (with new query for edit, or original content for reload) @@ -1810,21 +1887,14 @@ export default function NewChatPage() { createdAt: new Date(), metadata: isEdit ? undefined : originalUserMessageMetadata, }; - setMessages((prev) => [...prev, userMessage]); - - // Add placeholder assistant message - setMessages((prev) => [ - ...prev, - { - id: assistantMsgId, - role: "assistant", - content: [{ type: "text", text: "" }], - createdAt: new Date(), - }, - ]); - + const sourceMentionedDocs = + sourceUserMessageId && messageDocumentsMap[sourceUserMessageId] + ? messageDocumentsMap[sourceUserMessageId] + : []; try { - const selection = await getAgentFilesystemSelection(searchSpaceId); + const selection = await getAgentFilesystemSelection(searchSpaceId, { + localFilesystemEnabled, + }); const requestBody: Record = { search_space_id: searchSpaceId, user_query: newUserQuery, @@ -1832,6 +1902,18 @@ export default function NewChatPage() { filesystem_mode: selection.filesystem_mode, client_platform: selection.client_platform, local_filesystem_mounts: selection.local_filesystem_mounts, + // Full mention metadata for the regenerate-specific + // source list. Only meaningful for edit (the BE only + // re-persists a user row when ``user_query`` is set); + // reload reuses the original turn's mentioned_documents. + mentioned_documents: + sourceMentionedDocs.length > 0 + ? sourceMentionedDocs.map((d) => ({ + id: d.id, + title: d.title, + document_type: d.document_type, + })) + : undefined, }; if (isEdit) { requestBody.user_images = editExtras?.userImages ?? []; @@ -1846,18 +1928,56 @@ export default function NewChatPage() { requestBody.revert_actions = true; } } - const response = await fetch(getRegenerateUrl(threadId), { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify(requestBody), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(getRegenerateUrl(threadId), { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify(requestBody), + signal: controller.signal, + }) + ); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); + } + regenerateAccepted = true; + + // Only switch UI to regenerated placeholder messages after the backend accepts + // regenerate. This avoids local message loss when regenerate fails early (e.g. 400). + // + // When an explicit ``editFromPosition.fromMessageId`` is passed, slice from + // that message forward so edit-from-arbitrary-position drops every downstream + // message; otherwise fall back to the legacy "drop the last 2" behaviour. 
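The rewind rule in the comment above is easiest to verify in isolation; the `setMessages` updater that follows implements exactly this logic (sketch with an illustrative helper name):

```ts
// Given prev ids ["msg-1", "msg-2", "msg-3", "msg-4"]:
//   rewindBase(prev, 3)         -> ["msg-1", "msg-2"]  (explicit edit position)
//   rewindBase(prev, undefined) -> ["msg-1", "msg-2"]  (legacy drop-the-last-2)
//   rewindBase(prev, 99)        -> prev unchanged      (unknown id: no guessing)
function rewindBase<T extends { id: string }>(prev: T[], fromMessageId?: number): T[] {
	if (fromMessageId != null) {
		const sliceIndex = prev.findIndex((m) => m.id === `msg-${fromMessageId}`);
		return sliceIndex >= 0 ? prev.slice(0, sliceIndex) : prev;
	}
	return prev.length >= 2 ? prev.slice(0, -2) : prev;
}
```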
+ setMessages((prev) => { + let base = prev; + if (editFromPosition?.fromMessageId != null) { + const targetId = `msg-${editFromPosition.fromMessageId}`; + const sliceIndex = prev.findIndex((m) => m.id === targetId); + if (sliceIndex >= 0) { + base = prev.slice(0, sliceIndex); + } + } else if (prev.length >= 2) { + base = prev.slice(0, -2); + } + return [ + ...base, + userMessage, + { + id: assistantMsgId, + role: "assistant", + content: [{ type: "text", text: "" }], + createdAt: new Date(), + }, + ]; + }); + if (sourceMentionedDocs.length > 0) { + setMessageDocumentsMap((prev) => ({ + ...prev, + [userMsgId]: sourceMentionedDocs, + })); } const flushMessages = () => { @@ -1869,111 +1989,41 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? 
{}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { - const idx = toolCallIndices.get(parsed.toolCallId); - if (idx !== undefined) { - const part = contentParts[idx]; - if (part?.type === "tool-call" && part.toolName === "generate_podcast") { - setActivePodcastTaskId(String(parsed.output.podcast_id)); + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageStore.set(assistantMsgId, data); + }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } } } - } - forceFlush(); - break; - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - break; - } - + }, + }) + ) { + return; + } + switch (parsed.type) { case "data-action-log": { if (threadId !== null) { applyActionLogSse(queryClient, threadId, searchSpaceId, parsed.data); @@ -1994,17 +2044,60 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + if (turnId) { setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; } + case "data-user-message-id": { + // Same role as in ``onNew`` but the regenerate-specific + // mention metadata (``sourceMentionedDocs``) is the + // list to migrate onto the canonical id key. + const parsedMsg = readStreamedMessageId(parsed.data); + if (!parsedMsg) break; + const newUserMsgId = `msg-${parsedMsg.messageId}`; + const oldUserMsgId = userMsgId; + setMessages((prev) => + prev.map((m) => + m.id === oldUserMsgId + ? 
mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, parsedMsg.turnId) + : m + ) + ); + if (sourceMentionedDocs.length > 0) { + setMessageDocumentsMap((prev) => { + if (!(oldUserMsgId in prev)) { + return { ...prev, [newUserMsgId]: sourceMentionedDocs }; + } + const { [oldUserMsgId]: _removed, ...rest } = prev; + return { ...rest, [newUserMsgId]: sourceMentionedDocs }; + }); + } + userMsgId = newUserMsgId; + break; + } + + case "data-assistant-message-id": { + const parsedMsg = readStreamedMessageId(parsed.data); + if (!parsedMsg) break; + const newAssistantMsgId = `msg-${parsedMsg.messageId}`; + const oldAssistantMsgId = assistantMsgId; + tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId); + setMessages((prev) => + prev.map((m) => + m.id === oldAssistantMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newAssistantMsgId }, parsedMsg.turnId) + : m + ) + ); + assistantMsgId = newAssistantMsgId; + break; + } + case "data-revert-results": { const summary = parsed.data; // failureCount must include every "not undone" bucket @@ -2040,95 +2133,48 @@ export default function NewChatPage() { } break; } - - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); - break; - - case "error": - throw new Error(parsed.errorText || "Server error"); } - } + }); batcher.flush(); - // Persist messages after streaming completes - const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); + // Server-authoritative persistence: ``stream_new_chat`` + // (regenerate flow) wrote the user row in + // ``persist_user_turn`` and finalises the assistant row + // in ``finalize_assistant_turn`` from a shielded + // ``finally`` block (covers both happy-path and + // abort-mid-stream). FE only needs to track the + // successful response here. if (contentParts.length > 0) { - try { - // Persist user message (for both edit and reload modes, since backend deleted it) - const userContentToPersist = isEdit - ? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }]) - : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; - - const savedUserMessage = await appendMessage(threadId, { - role: "user", - content: userContentToPersist, - turn_id: streamedChatTurnId, - }); - - // Update user message ID to database ID - const newUserMsgId = `msg-${savedUserMessage.id}`; - setMessages((prev) => - prev.map((m) => - m.id === userMsgId - ? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id) - : m - ) - ); - - // Persist assistant message - const savedMessage = await appendMessage(threadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - turn_id: streamedChatTurnId, - }); - - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? 
mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) - : m - ) - ); - - trackChatResponseReceived(searchSpaceId, threadId); - } catch (err) { - console.error("Failed to persist regenerated message:", err); - } + trackChatResponseReceived(searchSpaceId, threadId); } } catch (error) { - if (error instanceof Error && error.name === "AbortError") { - return; - } - batcher.dispose(); - console.error("[NewChatPage] Regeneration error:", error); - trackChatError( - searchSpaceId, + streamBatcher?.dispose(); + await handleStreamTerminalError({ + error, + flow: "regenerate", threadId, - error instanceof Error ? error.message : "Unknown error" - ); - toast.error("Failed to regenerate response. Please try again."); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? { - ...m, - content: [{ type: "text", text: "Sorry, there was an error. Please try again." }], - } - : m - ) - ); + assistantMsgId, + accepted: regenerateAccepted, + }); } finally { setIsRunning(false); abortControllerRef.current = null; } }, - [threadId, searchSpaceId, messages, disabledTools, tokenUsageStore, toolsWithUI] + [ + threadId, + searchSpaceId, + messages, + disabledTools, + localFilesystemEnabled, + messageDocumentsMap, + setMessageDocumentsMap, + queryClient, + tokenUsageStore, + fetchWithTurnCancellingRetry, + handleStreamTerminalError, + ] ); // Handle editing a message - truncates history and regenerates with new query. @@ -2162,7 +2208,11 @@ export default function NewChatPage() { if (fromMessageId == null) { // No source id (or non-DB id) — fall back to today's // last-2 behaviour. The user gets the legacy edit flow. - await handleRegenerate(queryForApi, { userMessageContent, userImages }); + await handleRegenerate(queryForApi, { + userMessageContent, + userImages, + sourceUserMessageId: sourceId, + }); return; } @@ -2211,7 +2261,7 @@ export default function NewChatPage() { // Nothing to revert — submit silently. await handleRegenerate( queryForApi, - { userMessageContent, userImages }, + { userMessageContent, userImages, sourceUserMessageId: sourceId }, { fromMessageId, revertActions: false } ); return; @@ -2246,6 +2296,7 @@ export default function NewChatPage() { { userMessageContent: pending.userMessageContent, userImages: pending.userImages, + sourceUserMessageId: `msg-${pending.fromMessageId}`, }, { fromMessageId: pending.fromMessageId, diff --git a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx index 67d9edab0..85bc4aaa6 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx @@ -1,11 +1,8 @@ "use client"; -import { useQueryClient } from "@tanstack/react-query"; import { CheckCircle2 } from "lucide-react"; import Link from "next/link"; import { useParams } from "next/navigation"; -import { useEffect } from "react"; -import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms"; import { Button } from "@/components/ui/button"; import { Card, @@ -18,14 +15,8 @@ import { export default function PurchaseSuccessPage() { const params = useParams(); - const queryClient = useQueryClient(); const searchSpaceId = String(params.search_space_id ?? ""); - useEffect(() => { - void queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY }); - void queryClient.invalidateQueries({ queryKey: ["token-status"] }); - }, [queryClient]); - return (
diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx index bd8f03a70..17d8aa50c 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx @@ -178,6 +178,19 @@ const FLAG_GROUPS: FlagGroup[] = [ }, ], }, + { + id: "desktop", + title: "Desktop", + subtitle: "Desktop-only capabilities exposed by the backend deployment.", + flags: [ + { + key: "enable_desktop_local_filesystem", + label: "Local filesystem", + description: "Allow Desktop chat sessions to operate directly on selected local folders.", + envVar: "ENABLE_DESKTOP_LOCAL_FILESYSTEM", + }, + ], + }, ]; function FlagRow({ def, value }: { def: FlagDef; value: boolean }) { diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx index 2b7422f80..cf73b5eba 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx @@ -28,6 +28,12 @@ type UnifiedPurchase = { kind: PurchaseKind; created_at: string; status: PagePurchaseStatus; + /** + * Granted units. Interpretation depends on ``kind``: + * - ``"pages"`` — integer number of indexed pages. + * - ``"tokens"`` — integer micro-USD of credit (1_000_000 = $1.00). + * The ``Granted`` column formats accordingly. + */ granted: number; amount_total: number | null; currency: string | null; @@ -58,7 +64,7 @@ const KIND_META: Record< iconClass: "text-sky-500", }, tokens: { - label: "Premium Tokens", + label: "Premium Credit", icon: Coins, iconClass: "text-amber-500", }, @@ -97,12 +103,25 @@ function normalizeTokenPurchase(p: TokenPurchase): UnifiedPurchase { kind: "tokens", created_at: p.created_at, status: p.status, - granted: p.tokens_granted, + granted: p.credit_micros_granted, amount_total: p.amount_total, currency: p.currency, }; } +function formatGranted(p: UnifiedPurchase): string { + if (p.kind === "tokens") { + const dollars = p.granted / 1_000_000; + // Premium credit packs are always whole dollars at the moment, but + // future fractional grants (refunds, partial top-ups) shouldn't + // silently round to "$0". + if (dollars >= 1) return `$${dollars.toFixed(2)} of credit`; + if (dollars > 0) return `$${dollars.toFixed(3)} of credit`; + return "$0 of credit"; + } + return p.granted.toLocaleString(); +} + export function PurchaseHistoryContent() { const results = useQueries({ queries: [ @@ -143,7 +162,7 @@ export function PurchaseHistoryContent() {

No purchases yet

- Your page and premium token purchases will appear here after checkout. + Your page and premium credit purchases will appear here after checkout.
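Concretely, the `granted` column now formats per `kind`. A standalone restatement of the rule in `formatGranted`, reduced to the two fields the helper actually reads (hand-traced outputs in the comments, en-US locale assumed):

```ts
function formatGrantedUnits(kind: "pages" | "tokens", granted: number): string {
	if (kind === "tokens") {
		const dollars = granted / 1_000_000; // micro-USD -> USD
		if (dollars >= 1) return `$${dollars.toFixed(2)} of credit`;
		if (dollars > 0) return `$${dollars.toFixed(3)} of credit`; // keep fractional grants visible
		return "$0 of credit";
	}
	return granted.toLocaleString(); // page counts stay plain integers
}

formatGrantedUnits("tokens", 5_000_000); // "$5.00 of credit"
formatGrantedUnits("tokens", 500_000);   // "$0.500 of credit"
formatGrantedUnits("pages", 1200);       // "1,200"
```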

); @@ -177,7 +196,7 @@ export function PurchaseHistoryContent() { - {p.granted.toLocaleString()} + {formatGranted(p)} {formatAmount(p.amount_total, p.currency)} diff --git a/surfsense_web/app/desktop/permissions/page.tsx b/surfsense_web/app/desktop/permissions/page.tsx index e30a76f83..ca9228272 100644 --- a/surfsense_web/app/desktop/permissions/page.tsx +++ b/surfsense_web/app/desktop/permissions/page.tsx @@ -132,8 +132,8 @@ export default function DesktopPermissionsPage() {

System Permissions

- SurfSense needs two macOS permissions for Screenshot Assist and for desktop features that - require focusing the app or the active application. + SurfSense needs two macOS permissions for Screenshot Assist and for desktop features + that require focusing the app or the active application.

diff --git a/surfsense_web/atoms/chat/current-thread.atom.ts b/surfsense_web/atoms/chat/current-thread.atom.ts
index d781df8d2..131c98309 100644
--- a/surfsense_web/atoms/chat/current-thread.atom.ts
+++ b/surfsense_web/atoms/chat/current-thread.atom.ts
@@ -26,7 +26,14 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat
 export const resetCurrentThreadAtom = atom(null, (_, set) => {
 	set(currentThreadAtom, initialState);
-	set(reportPanelAtom, { isOpen: false, reportId: null, title: null, wordCount: null });
+	set(reportPanelAtom, {
+		isOpen: false,
+		reportId: null,
+		title: null,
+		wordCount: null,
+		shareToken: null,
+		contentType: "markdown",
+	});
 });
 
 /** Target comment ID to scroll to (from URL navigation or inbox click) */
diff --git a/surfsense_web/atoms/chat/premium-alert.atom.ts b/surfsense_web/atoms/chat/premium-alert.atom.ts
new file mode 100644
index 000000000..1c837dd65
--- /dev/null
+++ b/surfsense_web/atoms/chat/premium-alert.atom.ts
@@ -0,0 +1,45 @@
+import { atom } from "jotai";
+
+export type PremiumAlertState = {
+	message: string;
+};
+
+export const premiumAlertByThreadAtom = atom<Record<number, PremiumAlertState>>({});
+
+export const setPremiumAlertForThreadAtom = atom(
+	null,
+	(
+		get,
+		set,
+		payload: {
+			threadId: number;
+			message: string;
+			userId?: string | null;
+		}
+	) => {
+		const storageKey = `surfsense-premium-alert-seen-v1:${payload.userId ?? "anonymous"}`;
+
+		if (typeof window !== "undefined") {
+			const hasSeen = localStorage.getItem(storageKey) === "true";
+			if (hasSeen) return;
+		}
+
+		const current = get(premiumAlertByThreadAtom);
+		set(premiumAlertByThreadAtom, {
+			...current,
+			[payload.threadId]: { message: payload.message },
+		});
+
+		if (typeof window !== "undefined") {
+			localStorage.setItem(storageKey, "true");
+		}
+	}
+);
+
+export const clearPremiumAlertForThreadAtom = atom(null, (get, set, threadId: number) => {
+	const current = get(premiumAlertByThreadAtom);
+	if (!(threadId in current)) return;
+	const next = { ...current };
+	delete next[threadId];
+	set(premiumAlertByThreadAtom, next);
+});
diff --git a/surfsense_web/atoms/user/user-query.atoms.ts b/surfsense_web/atoms/user/user-query.atoms.ts
index 8e196c9c7..4b6717440 100644
--- a/surfsense_web/atoms/user/user-query.atoms.ts
+++ b/surfsense_web/atoms/user/user-query.atoms.ts
@@ -8,7 +8,10 @@ const userQueryFn = () => userApiService.getMe();
 export const currentUserAtom = atomWithQuery(() => {
	return {
		queryKey: USER_QUERY_KEY,
-		staleTime: 5 * 60 * 1000,
+		// Live-changing numeric fields (pages_*, premium_credit_micros_*)
+		// are now pushed via Zero (queries.user.me()), so /users/me only
+		// needs to fire once per session for the static profile fields.
+ staleTime: Infinity, enabled: !!getBearerToken(), queryFn: userQueryFn, }; diff --git a/surfsense_web/components/agent-action-log/action-log-sheet.tsx b/surfsense_web/components/agent-action-log/action-log-sheet.tsx index 32c25771a..7d27b4019 100644 --- a/surfsense_web/components/agent-action-log/action-log-sheet.tsx +++ b/surfsense_web/components/agent-action-log/action-log-sheet.tsx @@ -17,10 +17,7 @@ import { SheetTitle, } from "@/components/ui/sheet"; import { Skeleton } from "@/components/ui/skeleton"; -import { - agentActionsQueryKey, - useAgentActionsQuery, -} from "@/hooks/use-agent-actions-query"; +import { agentActionsQueryKey, useAgentActionsQuery } from "@/hooks/use-agent-actions-query"; import { ActionLogItem } from "./action-log-item"; function EmptyState() { diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 048837c89..7bccc22ee 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -400,6 +400,19 @@ function formatMessageDate(date: Date): string { }); } +/** + * Format provider USD cost (in micro-USD) for inline display next to a + * token count. Falls back to ``"<$0.001"`` for sub-tenth-of-a-cent + * costs so a real-but-tiny figure doesn't render as ``$0.000``. + */ +function formatTurnCost(micros: number): string { + const dollars = micros / 1_000_000; + if (dollars >= 1) return `$${dollars.toFixed(2)}`; + if (dollars >= 0.01) return `$${dollars.toFixed(3)}`; + if (dollars > 0) return "<$0.001"; + return "$0"; +} + const MessageInfoDropdown: FC = () => { const messageId = useAuiState(({ message }) => message?.id); const createdAt = useAuiState(({ message }) => message?.createdAt); @@ -452,6 +465,7 @@ const MessageInfoDropdown: FC = () => { {models.length > 0 ? ( models.map(([model, counts]) => { const { name, icon } = resolveModel(model); + const costMicros = counts.cost_micros; return ( { {counts.total_tokens.toLocaleString()} tokens + {costMicros && costMicros > 0 ? ` · ${formatTurnCost(costMicros)}` : ""} ); @@ -475,6 +490,9 @@ const MessageInfoDropdown: FC = () => { > {usage.total_tokens.toLocaleString()} tokens + {usage.cost_micros && usage.cost_micros > 0 + ? ` · ${formatTurnCost(usage.cost_micros)}` + : ""} )} @@ -555,8 +573,10 @@ const AssistantMessageInner: FC = () => { )} -
); @@ -649,35 +669,41 @@ export const AssistantMessage: FC = () => { className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150" data-role="assistant" > - {/* Comment trigger — right-aligned, just below user query on all screen sizes */} - {showCommentTrigger && ( -
- )}
+ {/* Fixed trigger slot prevents any vertical reflow when visibility changes */}
{/* Desktop floating comment panel — overlays on top of chat content */} {showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && ( diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx new file mode 100644 index 000000000..c0684407e --- /dev/null +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -0,0 +1,52 @@ +"use client"; + +import { ThreadPrimitive } from "@assistant-ui/react"; +import { ArrowDownIcon } from "lucide-react"; +import type { FC, ReactNode } from "react"; +import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; + +const ChatScrollToBottom: FC = () => ( + + + + + +); + +export interface ChatViewportProps { + children: ReactNode; + footer?: ReactNode; +} + +export const ChatViewport: FC = ({ children, footer }) => ( + +
+ {children} + {footer ? ( + +
+ + {footer} +
+
+ ) : null} + +); diff --git a/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json b/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json index f62758256..b4e85eab0 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json +++ b/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json @@ -9,6 +9,16 @@ "enabled": true, "status": "warning", "statusMessage": "Some requests may be blocked if not using Firecrawl." + }, + "JIRA_CONNECTOR": { + "enabled": false, + "status": "maintenance", + "statusMessage": "Rework in progress." + }, + "CONFLUENCE_CONNECTOR": { + "enabled": false, + "status": "maintenance", + "statusMessage": "Rework in progress." } }, "globalSettings": { diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index ae2c413cf..2f9605ea7 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -105,14 +105,14 @@ export const OAUTH_CONNECTORS = [ { id: "jira-connector", title: "Jira", - description: "Search, read, and manage issues", + description: "Rework in progress.", connectorType: EnumConnectorName.JIRA_CONNECTOR, authEndpoint: "/api/v1/auth/mcp/jira/connector/add/", }, { id: "confluence-connector", title: "Confluence", - description: "Search documentation", + description: "Rework in progress.", connectorType: EnumConnectorName.CONFLUENCE_CONNECTOR, authEndpoint: "/api/v1/auth/confluence/connector/add/", }, diff --git a/surfsense_web/components/assistant-ui/inline-citation.tsx b/surfsense_web/components/assistant-ui/inline-citation.tsx index 2aeba89ca..32a29cfc9 100644 --- a/surfsense_web/components/assistant-ui/inline-citation.tsx +++ b/surfsense_web/components/assistant-ui/inline-citation.tsx @@ -3,11 +3,11 @@ import { useQuery } from "@tanstack/react-query"; import { useSetAtom } from "jotai"; import { ExternalLink, FileText } from "lucide-react"; +import dynamic from "next/dynamic"; import type { FC } from "react"; import { useCallback, useEffect, useRef, useState } from "react"; import { openCitationPanelAtom } from "@/atoms/citation/citation-panel.atom"; import { useCitationMetadata } from "@/components/assistant-ui/citation-metadata-context"; -import { MarkdownViewer } from "@/components/markdown-viewer"; import { Citation } from "@/components/tool-ui/citation"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Spinner } from "@/components/ui/spinner"; @@ -15,6 +15,16 @@ import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip import { documentsApiService } from "@/lib/apis/documents-api.service"; import { cacheKeys } from "@/lib/query-client/cache-keys"; +// Lazily load MarkdownViewer here to break the static import cycle: +// `markdown-viewer.tsx` → `citation-renderer.tsx` → `inline-citation.tsx` +// would otherwise pull `markdown-viewer.tsx` back in at module-init time. +// Only `SurfsenseDocCitation` (popover body) ever renders this viewer, so +// the lazy boundary is invisible to most call paths. 
+const MarkdownViewer = dynamic(
+	() => import("@/components/markdown-viewer").then((m) => m.MarkdownViewer),
+	{ ssr: false, loading: () => <Spinner /> }
+);
+
 interface InlineCitationProps {
 	chunkId: number;
 	isDocsChunk?: boolean;
@@ -172,7 +182,7 @@ const SurfsenseDocCitation: FC<{ chunkId: number }> = ({ chunkId }) => {

)} {!isLoading && !error && citedChunk?.content && ( - + )} {!isLoading && !error && !citedChunk?.content && (

No content available.

diff --git a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx index 05277f508..c585dc80f 100644 --- a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx +++ b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx @@ -1,26 +1,19 @@ "use client"; -import { X } from "lucide-react"; -import type { ReactElement } from "react"; +import type { PlateElementProps } from "platejs/react"; import { - createElement, - forwardRef, - useCallback, - useEffect, - useImperativeHandle, - useRef, - useState, -} from "react"; -import { renderToStaticMarkup } from "react-dom/server"; + createPlatePlugin, + ParagraphPlugin, + Plate, + PlateContent, + usePlateEditor, +} from "platejs/react"; +import { type FC, forwardRef, useCallback, useImperativeHandle, useMemo, useRef } from "react"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { Document } from "@/contracts/types/document.types"; import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { cn } from "@/lib/utils"; -function renderElementToHTML(element: ReactElement): string { - return renderToStaticMarkup(element); -} - export interface MentionedDocument { id: number; title: string; @@ -61,38 +54,178 @@ interface InlineMentionEditorProps { initialText?: string; } -// Unique data attribute to identify chip elements -const CHIP_DATA_ATTR = "data-mention-chip"; -const CHIP_ID_ATTR = "data-mention-id"; -const CHIP_DOCTYPE_ATTR = "data-mention-doctype"; -const CHIP_STATUS_ATTR = "data-mention-status"; +type MentionStatusKind = "pending" | "processing" | "ready" | "failed"; +type ComposerTextNode = { text: string }; +type MentionElementNode = { + type: "mention"; + id: number; + title: string; + document_type?: string; + statusLabel?: string | null; + statusKind?: MentionStatusKind; + children: [{ text: "" }]; +}; +type ComposerNode = ComposerTextNode | MentionElementNode; +type ComposerParagraph = { type: "p"; children: ComposerNode[] }; +type ComposerValue = ComposerParagraph[]; + +const MENTION_TYPE = "mention"; +const MENTION_CHIP_CLASSNAME = + "inline-flex h-5 items-center gap-1 mx-0.5 rounded bg-primary/10 px-1 text-xs font-bold text-primary/60 select-none align-middle leading-none"; +const MENTION_CHIP_ICON_CLASSNAME = "flex items-center text-muted-foreground leading-none"; +const MENTION_CHIP_TITLE_CLASSNAME = "max-w-[120px] truncate leading-none"; +const COMPOSER_TEXT_METRICS_CLASSNAME = "text-sm leading-6"; + +const EMPTY_VALUE: ComposerValue = [{ type: "p", children: [{ text: "" }] }]; + +const MentionElement: FC> = ({ + attributes, + children, + element, +}) => { + const statusClass = + element.statusKind === "failed" + ? "text-destructive" + : element.statusKind === "ready" + ? "text-emerald-700" + : "text-amber-700"; -/** - * Type guard to check if a node is a chip element - */ -function isChipElement(node: Node | null): node is HTMLSpanElement { return ( - node !== null && - node.nodeType === Node.ELEMENT_NODE && - (node as Element).hasAttribute(CHIP_DATA_ATTR) + + + + {getConnectorIcon(element.document_type ?? "UNKNOWN", "h-3 w-3")} + + + {element.title} + + {element.statusLabel ? 
( + + {element.statusLabel} + + ) : null} + + {children} + ); +}; + +const MentionPlugin = createPlatePlugin({ + key: MENTION_TYPE, + node: { + isElement: true, + isInline: true, + isVoid: true, + type: MENTION_TYPE, + component: MentionElement, + }, +}); + +function isMentionNode(node: ComposerNode): node is MentionElementNode { + return typeof node === "object" && "type" in node && node.type === MENTION_TYPE; } -/** - * Safely parse chip ID from element attribute - */ -function getChipId(element: Element): number | null { - const idStr = element.getAttribute(CHIP_ID_ATTR); - if (!idStr) return null; - const id = parseInt(idStr, 10); - return Number.isNaN(id) ? null : id; +function getTextNode(node: ComposerNode): ComposerTextNode | null { + if (typeof node === "object" && "text" in node && typeof node.text === "string") return node; + return null; } -/** - * Get chip document type from element attribute - */ -function getChipDocType(element: Element): string { - return element.getAttribute(CHIP_DOCTYPE_ATTR) ?? "UNKNOWN"; +function toValueFromText(text: string): ComposerValue { + const lines = text.split("\n"); + if (lines.length === 0) return EMPTY_VALUE; + return lines.map((line) => ({ type: "p", children: [{ text: line }] })) as ComposerValue; +} + +function getPlainText(value: ComposerValue): string { + const lines = value.map((block) => + block.children + .map((node) => { + if (isMentionNode(node)) return `@${node.title}`; + return getTextNode(node)?.text ?? ""; + }) + .join("") + ); + return lines.join("\n").trim(); +} + +function getMentionedDocuments(value: ComposerValue): MentionedDocument[] { + const map = new Map(); + for (const block of value) { + for (const node of block.children) { + if (!isMentionNode(node)) continue; + const doc: MentionedDocument = { + id: node.id, + title: node.title, + document_type: node.document_type, + }; + map.set(getMentionDocKey(doc), doc); + } + } + return Array.from(map.values()); +} + +type EditorSelection = { + anchor: { path: number[]; offset: number }; + focus: { path: number[]; offset: number }; +} | null; + +function getCursorTextContext(value: ComposerValue, selection: EditorSelection) { + if (!selection || !selection.anchor || !selection.focus) return null; + if ( + selection.anchor.path.length < 2 || + selection.focus.path.length < 2 || + selection.anchor.path[0] !== selection.focus.path[0] || + selection.anchor.path[1] !== selection.focus.path[1] + ) { + return null; + } + + const block = value[selection.anchor.path[0]]; + if (!block) return null; + const child = block.children[selection.anchor.path[1]]; + const textNode = getTextNode(child); + if (!textNode) return null; + + return { + blockIndex: selection.anchor.path[0], + childIndex: selection.anchor.path[1], + text: textNode.text, + cursor: selection.anchor.offset, + }; +} + +function scanActiveTrigger(text: string, cursor: number) { + let wordStart = 0; + for (let i = cursor - 1; i >= 0; i--) { + if (text[i] === " " || text[i] === "\n") { + wordStart = i + 1; + break; + } + } + + let triggerChar: "@" | "/" | null = null; + let triggerIndex = -1; + for (let i = wordStart; i < cursor; i++) { + if (text[i] === "@" || text[i] === "/") { + triggerChar = text[i] as "@" | "/"; + triggerIndex = i; + break; + } + } + if (!triggerChar || triggerIndex === -1) return null; + + const query = text.slice(triggerIndex + 1, cursor); + if (query.startsWith(" ")) return null; + if ( + triggerChar === "/" && + triggerIndex > 0 && + text[triggerIndex - 1] !== " " && + text[triggerIndex - 1] !== "\n" 
+ ) { + return null; + } + + return { triggerChar, query }; } export const InlineMentionEditor = forwardRef( @@ -113,393 +246,163 @@ export const InlineMentionEditor = forwardRef { - const editorRef = useRef(null); - const [isEmpty, setIsEmpty] = useState(true); - const [mentionedDocs, setMentionedDocs] = useState>( - () => new Map() - ); - const isComposingRef = useRef(false); - const lastSelectionRangeRef = useRef(null); - const isRangeInsideEditor = useCallback((range: Range | null): range is Range => { - if (!range || !editorRef.current) return false; - return ( - editorRef.current.contains(range.startContainer) && - editorRef.current.contains(range.endContainer) - ); - }, []); - const isSelectionInsideEditor = useCallback( - (selection: Selection | null): selection is Selection => { - if (!selection || selection.rangeCount === 0 || !editorRef.current) return false; - const range = selection.getRangeAt(0); - return isRangeInsideEditor(range); - }, - [isRangeInsideEditor] - ); + const editableRef = useRef(null); + const editor = usePlateEditor({ + readOnly: disabled, + plugins: [ParagraphPlugin, MentionPlugin], + value: initialText ? toValueFromText(initialText) : EMPTY_VALUE, + }); - const rememberSelection = useCallback(() => { - const selection = window.getSelection(); - if (!isSelectionInsideEditor(selection)) return; - lastSelectionRangeRef.current = selection.getRangeAt(0).cloneRange(); - }, [isSelectionInsideEditor]); - - const restoreRememberedSelection = useCallback((): Selection | null => { - const selection = window.getSelection(); - if (!selection) return null; - if (!isRangeInsideEditor(lastSelectionRangeRef.current)) return null; - selection.removeAllRanges(); - selection.addRange(lastSelectionRangeRef.current.cloneRange()); - return selection; - }, [isRangeInsideEditor]); - - useEffect(() => { - const handleSelectionChange = () => { - if (document.activeElement !== editorRef.current) return; - rememberSelection(); - }; - document.addEventListener("selectionchange", handleSelectionChange); - return () => document.removeEventListener("selectionchange", handleSelectionChange); - }, [rememberSelection]); - - useEffect(() => { - if (!initialText || !editorRef.current) return; - editorRef.current.innerText = initialText; - editorRef.current.appendChild(document.createElement("br")); - editorRef.current.appendChild(document.createElement("br")); - setIsEmpty(false); - onChange?.(initialText, []); - editorRef.current.focus(); - const sel = window.getSelection(); - const range = document.createRange(); - range.selectNodeContents(editorRef.current); - range.collapse(false); - sel?.removeAllRanges(); - sel?.addRange(range); - const anchor = document.createElement("span"); - range.insertNode(anchor); - anchor.scrollIntoView({ block: "end" }); - anchor.remove(); - }, [initialText, onChange]); - - // Focus at the end of the editor const focusAtEnd = useCallback(() => { - if (!editorRef.current) return; - editorRef.current.focus(); + const el = editableRef.current; + if (!el) return; + el.focus(); const selection = window.getSelection(); const range = document.createRange(); - range.selectNodeContents(editorRef.current); + range.selectNodeContents(el); range.collapse(false); selection?.removeAllRanges(); selection?.addRange(range); }, []); - // Get plain text content with inline mention tokens for chips. - // This preserves the original query structure sent to the backend/LLM. 
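For reference, the flattening contract of the new `getPlainText` above: a mention node contributes an inline `@title` token, so the backend still receives the original query shape. A hand-traced example built from this diff's `ComposerValue` types (the `document_type` value is illustrative):

```ts
const value: ComposerValue = [
	{
		type: "p",
		children: [
			{ text: "summarize " },
			{
				type: "mention",
				id: 42,
				title: "Q3 roadmap",
				document_type: "NOTION_CONNECTOR", // illustrative connector type
				children: [{ text: "" }],
			},
			{ text: " for me" },
		],
	},
];

getPlainText(value);          // => "summarize @Q3 roadmap for me"
getMentionedDocuments(value); // => [{ id: 42, title: "Q3 roadmap", document_type: "NOTION_CONNECTOR" }]
```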
- const getText = useCallback((): string => { - if (!editorRef.current) return ""; + const getCurrentValue = useCallback( + () => (editor.children as ComposerValue) ?? EMPTY_VALUE, + [editor] + ); - const extractText = (node: Node): string => { - if (node.nodeType === Node.TEXT_NODE) { - return node.textContent ?? ""; - } - - if (node.nodeType === Node.ELEMENT_NODE) { - const element = node as Element; - - // Preserve mention chips as inline @title tokens. - if (element.hasAttribute(CHIP_DATA_ATTR)) { - const title = element.querySelector("[data-mention-title='true']")?.textContent?.trim(); - if (title) { - return `@${title}`; - } - return ""; - } - - let result = ""; - for (const child of Array.from(element.childNodes)) { - result += extractText(child); - } - return result; - } - - return ""; - }; - - return extractText(editorRef.current).trim(); - }, []); - - // Get all mentioned documents - const getMentionedDocuments = useCallback((): MentionedDocument[] => { - return Array.from(mentionedDocs.values()); - }, [mentionedDocs]); - - const syncEditorState = useCallback( - (docsOverride?: Map) => { - const docs = docsOverride - ? Array.from(docsOverride.values()) - : Array.from(mentionedDocs.values()); - const text = getText(); - const empty = text.length === 0 && docs.length === 0; - setIsEmpty(empty); + const emitState = useCallback( + (nextValue: ComposerValue) => { + const text = getPlainText(nextValue); + const docs = getMentionedDocuments(nextValue); onChange?.(text, docs); - }, - [getText, mentionedDocs, onChange] - ); - // Create a chip element for a document - const createChipElement = useCallback( - (doc: MentionedDocument): HTMLSpanElement => { - const chip = document.createElement("span"); - chip.setAttribute(CHIP_DATA_ATTR, "true"); - chip.setAttribute(CHIP_ID_ATTR, String(doc.id)); - chip.setAttribute(CHIP_DOCTYPE_ATTR, doc.document_type ?? "UNKNOWN"); - chip.contentEditable = "false"; - chip.className = - "inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none cursor-default"; - chip.style.userSelect = "none"; - chip.style.verticalAlign = "baseline"; - - // Container that swaps between icon and remove button on hover - const iconContainer = document.createElement("span"); - iconContainer.className = "shrink-0 flex items-center size-3 relative"; - - const iconSpan = document.createElement("span"); - iconSpan.className = "flex items-center text-muted-foreground"; - iconSpan.innerHTML = renderElementToHTML( - getConnectorIcon(doc.document_type ?? 
"UNKNOWN", "h-3 w-3") - ); - - const removeBtn = document.createElement("button"); - removeBtn.type = "button"; - removeBtn.className = - "size-3 items-center justify-center rounded-full text-muted-foreground transition-colors"; - removeBtn.style.display = "none"; - removeBtn.innerHTML = renderElementToHTML( - createElement(X, { className: "h-3 w-3", strokeWidth: 2.5 }) - ); - removeBtn.onclick = (e) => { - e.preventDefault(); - e.stopPropagation(); - chip.remove(); - const docKey = getMentionDocKey(doc); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(docKey); - syncEditorState(next); - return next; - }); - onDocumentRemove?.(doc.id, doc.document_type); - focusAtEnd(); - }; - - const titleSpan = document.createElement("span"); - titleSpan.className = "max-w-[120px] truncate"; - titleSpan.textContent = doc.title; - titleSpan.title = doc.title; - titleSpan.setAttribute("data-mention-title", "true"); - - const statusSpan = document.createElement("span"); - statusSpan.setAttribute(CHIP_STATUS_ATTR, "true"); - statusSpan.className = "text-[10px] font-semibold opacity-80 hidden"; - - const isTouchDevice = window.matchMedia("(hover: none)").matches; - if (isTouchDevice) { - // Mobile: icon on left, title, X on right - chip.appendChild(iconSpan); - chip.appendChild(titleSpan); - chip.appendChild(statusSpan); - removeBtn.style.display = "flex"; - removeBtn.className += " ml-0.5"; - chip.appendChild(removeBtn); - } else { - // Desktop: icon/X swap on hover in the same slot - iconContainer.appendChild(iconSpan); - iconContainer.appendChild(removeBtn); - chip.addEventListener("mouseenter", () => { - iconSpan.style.display = "none"; - removeBtn.style.display = "flex"; - }); - chip.addEventListener("mouseleave", () => { - iconSpan.style.display = ""; - removeBtn.style.display = "none"; - }); - chip.appendChild(iconContainer); - chip.appendChild(titleSpan); - chip.appendChild(statusSpan); + const cursorCtx = getCursorTextContext(nextValue, editor.selection); + if (!cursorCtx) { + onMentionClose?.(); + onActionClose?.(); + return; } - return chip; + const trigger = scanActiveTrigger(cursorCtx.text, cursorCtx.cursor); + if (!trigger) { + onMentionClose?.(); + onActionClose?.(); + return; + } + + if (trigger.triggerChar === "@") { + onMentionTrigger?.(trigger.query); + onActionClose?.(); + return; + } + + onActionTrigger?.(trigger.query); + onMentionClose?.(); }, - [focusAtEnd, onDocumentRemove, syncEditorState] + [editor.selection, onActionClose, onActionTrigger, onChange, onMentionClose, onMentionTrigger] + ); + + const setValue = useCallback( + (nextValue: ComposerValue) => { + const tf = editor.tf as { setValue: (value: ComposerValue) => void }; + tf.setValue(nextValue); + emitState(nextValue); + }, + [editor, emitState] ); - // Insert a document chip at the current cursor position const insertDocumentChip = useCallback( ( doc: Pick, options?: { removeTriggerText?: boolean } ) => { - if (!editorRef.current) return; + if (typeof doc.id !== "number" || typeof doc.title !== "string") return; + const removeTriggerText = options?.removeTriggerText ?? 
true; - - // Validate required fields for type safety - if (typeof doc.id !== "number" || typeof doc.title !== "string") { - console.warn("[InlineMentionEditor] Invalid document passed to insertDocumentChip:", doc); - return; - } - - const mentionDoc: MentionedDocument = { + const current = getCurrentValue(); + const selection = editor.selection; + const mentionNode: MentionElementNode = { + type: MENTION_TYPE, id: doc.id, title: doc.title, document_type: doc.document_type, + children: [{ text: "" }], }; - // Add to mentioned docs map using unique key - const docKey = getMentionDocKey(doc); - setMentionedDocs((prev) => new Map(prev).set(docKey, mentionDoc)); - const nextDocs = new Map(mentionedDocs); - nextDocs.set(docKey, mentionDoc); - - // Find and remove the @query text - const selection = window.getSelection(); - const hasActiveSelection = isSelectionInsideEditor(selection); - const resolvedSelection = hasActiveSelection ? selection : restoreRememberedSelection(); - if ( - !resolvedSelection || - resolvedSelection.rangeCount === 0 || - !isSelectionInsideEditor(resolvedSelection) - ) { - // No valid in-editor selection: deterministically insert at end. - editorRef.current.focus(); - const endSelection = window.getSelection(); - if (!endSelection) return; - const endRange = document.createRange(); - endRange.selectNodeContents(editorRef.current); - endRange.collapse(false); - endSelection.removeAllRanges(); - endSelection.addRange(endRange); - - const chip = createChipElement(mentionDoc); - endRange.insertNode(chip); - endRange.setStartAfter(chip); - endRange.collapse(true); - const space = document.createTextNode(" "); - endRange.insertNode(space); - endRange.setStartAfter(space); - endRange.collapse(true); - endSelection.removeAllRanges(); - endSelection.addRange(endRange); - - syncEditorState(nextDocs); - rememberSelection(); + const cursorCtx = getCursorTextContext(current, selection); + if (!cursorCtx) { + const lastBlock = current[current.length - 1] ?? 
{ type: "p", children: [{ text: "" }] }; + const appended: ComposerValue = [ + ...current.slice(0, -1), + { + ...lastBlock, + children: [...lastBlock.children, mentionNode, { text: " " }], + }, + ]; + setValue(appended); + requestAnimationFrame(focusAtEnd); return; } - // Find the @ symbol before the cursor and remove it along with any query text - const range = resolvedSelection.getRangeAt(0); - const textNode = range.startContainer; - - if (textNode.nodeType === Node.TEXT_NODE && removeTriggerText) { - const text = textNode.textContent || ""; - const cursorPos = range.startOffset; - - // Find the @ symbol before cursor - let atIndex = -1; - for (let i = cursorPos - 1; i >= 0; i--) { - if (text[i] === "@") { - atIndex = i; - break; - } - } - - if (atIndex !== -1) { - // Remove @query and insert chip - const beforeAt = text.slice(0, atIndex); - const afterCursor = text.slice(cursorPos); - - // Create chip - const chip = createChipElement(mentionDoc); - - // Replace text node content - const parent = textNode.parentNode; - if (parent) { - const beforeNode = document.createTextNode(beforeAt); - const afterNode = document.createTextNode(` ${afterCursor}`); - - parent.insertBefore(beforeNode, textNode); - parent.insertBefore(chip, textNode); - parent.insertBefore(afterNode, textNode); - parent.removeChild(textNode); - - // Set cursor after the chip - const newRange = document.createRange(); - newRange.setStart(afterNode, 1); - newRange.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(newRange); - rememberSelection(); - } - } else { - // No @ found, just insert at cursor - const chip = createChipElement(mentionDoc); - range.insertNode(chip); - range.setStartAfter(chip); - range.collapse(true); - - // Add space after chip - const space = document.createTextNode(" "); - range.insertNode(space); - range.setStartAfter(space); - range.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(range); - rememberSelection(); - } - } else { - // Either explicit non-trigger insertion or no @query present. 
- const chip = createChipElement(mentionDoc); - range.insertNode(chip); - range.setStartAfter(chip); - range.collapse(true); - const space = document.createTextNode(" "); - range.insertNode(space); - range.setStartAfter(space); - range.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(range); - rememberSelection(); + const block = current[cursorCtx.blockIndex]; + const currentChild = getTextNode(block.children[cursorCtx.childIndex]); + if (!currentChild) { + const children = [...block.children]; + children.splice(cursorCtx.childIndex + 1, 0, mentionNode, { text: " " }); + const next = [...current]; + next[cursorCtx.blockIndex] = { ...block, children }; + setValue(next as ComposerValue); + requestAnimationFrame(focusAtEnd); + return; } - syncEditorState(nextDocs); + const text = currentChild.text; + let removeStart = cursorCtx.cursor; + if (removeTriggerText) { + for (let i = cursorCtx.cursor - 1; i >= 0; i--) { + if (text[i] === "@") { + removeStart = i; + break; + } + if (text[i] === " " || text[i] === "\n") break; + } + } + + const before = text.slice(0, removeStart); + const after = text.slice(cursorCtx.cursor); + const replacement: ComposerNode[] = []; + if (before.length > 0) replacement.push({ text: before }); + replacement.push(mentionNode); + replacement.push({ text: ` ${after}` }); + + const children = [...block.children]; + children.splice(cursorCtx.childIndex, 1, ...replacement); + const next = [...current]; + next[cursorCtx.blockIndex] = { ...block, children }; + setValue(next as ComposerValue); + requestAnimationFrame(focusAtEnd); }, - [ - createChipElement, - isSelectionInsideEditor, - mentionedDocs, - rememberSelection, - restoreRememberedSelection, - syncEditorState, - ] + [editor.selection, focusAtEnd, getCurrentValue, setValue] ); - // Clear the editor - const clear = useCallback(() => { - if (editorRef.current) { - editorRef.current.innerHTML = ""; - const emptyDocs = new Map(); - setMentionedDocs(emptyDocs); - syncEditorState(emptyDocs); - } - }, [syncEditorState]); - - // Replace editor content with plain text and place cursor at end - const setText = useCallback( - (text: string) => { - if (!editorRef.current) return; - editorRef.current.innerText = text; - syncEditorState(); - focusAtEnd(); + const removeDocumentChip = useCallback( + (docId: number, docType?: string) => { + const current = getCurrentValue(); + let changed = false; + const next = current.map((block) => { + const children = block.children.filter((node) => { + if (!isMentionNode(node)) return true; + const match = + node.id === docId && (node.document_type ?? "UNKNOWN") === (docType ?? "UNKNOWN"); + if (match) changed = true; + return !match; + }); + return { ...block, children: children.length ? children : [{ text: "" }] }; + }); + if (!changed) return; + setValue(next as ComposerValue); }, - [focusAtEnd, syncEditorState] + [getCurrentValue, setValue] ); const setDocumentChipStatus = useCallback( @@ -507,327 +410,143 @@ export const InlineMentionEditor = forwardRef { - if (!editorRef.current) return; - - const chips = editorRef.current.querySelectorAll( - `span[${CHIP_DATA_ATTR}="true"]` - ); - for (const chip of chips) { - const chipId = getChipId(chip); - const chipType = getChipDocType(chip); - if (chipId !== docId) continue; - if ((docType ?? 
"UNKNOWN") !== chipType) continue; - - const statusEl = chip.querySelector(`span[${CHIP_STATUS_ATTR}="true"]`); - if (!statusEl) continue; - - if (!statusLabel) { - statusEl.textContent = ""; - statusEl.className = "text-[10px] font-semibold opacity-80 hidden"; - continue; - } - - const statusClass = - statusKind === "failed" - ? "text-destructive" - : statusKind === "processing" - ? "text-amber-700" - : statusKind === "ready" - ? "text-emerald-700" - : "text-amber-700"; - statusEl.textContent = statusLabel; - statusEl.className = `text-[10px] font-semibold opacity-80 ${statusClass}`; - } + const current = getCurrentValue(); + let changed = false; + const next = current.map((block) => ({ + ...block, + children: block.children.map((node) => { + if (!isMentionNode(node)) return node; + const sameType = (node.document_type ?? "UNKNOWN") === (docType ?? "UNKNOWN"); + if (node.id !== docId || !sameType) return node; + changed = true; + return { + ...node, + statusLabel, + statusKind: statusLabel ? statusKind : undefined, + }; + }), + })); + if (!changed) return; + setValue(next as ComposerValue); }, - [] + [getCurrentValue, setValue] ); - const removeDocumentChip = useCallback( - (docId: number, docType?: string) => { - if (!editorRef.current) return; - const chipKey = getMentionDocKey({ id: docId, document_type: docType }); - const chips = editorRef.current.querySelectorAll( - `span[${CHIP_DATA_ATTR}="true"]` - ); - for (const chip of chips) { - if (getChipId(chip) === docId && getChipDocType(chip) === (docType ?? "UNKNOWN")) { - chip.remove(); - break; - } - } - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); + const clear = useCallback(() => { + setValue(EMPTY_VALUE); + }, [setValue]); + + const setText = useCallback( + (text: string) => { + setValue(toValueFromText(text)); + requestAnimationFrame(focusAtEnd); }, - [syncEditorState] + [focusAtEnd, setValue] ); - // Expose methods via ref - useImperativeHandle(ref, () => ({ - focus: () => editorRef.current?.focus(), - clear, - setText, - getText, - getMentionedDocuments, - insertDocumentChip, - removeDocumentChip, - setDocumentChipStatus, - })); + const getText = useCallback(() => getPlainText(getCurrentValue()), [getCurrentValue]); + const getMentionedDocs = useCallback( + () => getMentionedDocuments(getCurrentValue()), + [getCurrentValue] + ); - // Handle input changes - const handleInput = useCallback(() => { - if (!editorRef.current) return; + useImperativeHandle( + ref, + () => ({ + focus: () => editableRef.current?.focus(), + clear, + setText, + getText, + getMentionedDocuments: getMentionedDocs, + insertDocumentChip, + removeDocumentChip, + setDocumentChipStatus, + }), + [ + clear, + getMentionedDocs, + getText, + insertDocumentChip, + removeDocumentChip, + setDocumentChipStatus, + setText, + ] + ); - const text = getText(); - const empty = text.length === 0 && mentionedDocs.size === 0; - setIsEmpty(empty); - - // Unified trigger scan: find the leftmost @ or / in the current word. - // Whichever trigger was typed first owns the token — the other character - // is treated as part of the query, not as a separate trigger. 
- const selection = window.getSelection(); - let shouldTriggerMention = false; - let mentionQuery = ""; - let shouldTriggerAction = false; - let actionQuery = ""; - - if (selection && selection.rangeCount > 0) { - const range = selection.getRangeAt(0); - const textNode = range.startContainer; - - if (textNode.nodeType === Node.TEXT_NODE) { - const textContent = textNode.textContent || ""; - const cursorPos = range.startOffset; - - let wordStart = 0; - for (let i = cursorPos - 1; i >= 0; i--) { - if (textContent[i] === " " || textContent[i] === "\n") { - wordStart = i + 1; - break; - } - } - - let triggerChar: "@" | "/" | null = null; - let triggerIndex = -1; - for (let i = wordStart; i < cursorPos; i++) { - if (textContent[i] === "@" || textContent[i] === "/") { - triggerChar = textContent[i] as "@" | "/"; - triggerIndex = i; - break; - } - } - - if (triggerChar === "@" && triggerIndex !== -1) { - const query = textContent.slice(triggerIndex + 1, cursorPos); - if (!query.startsWith(" ")) { - shouldTriggerMention = true; - mentionQuery = query; - } - } else if (triggerChar === "/" && triggerIndex !== -1) { - if ( - triggerIndex === 0 || - textContent[triggerIndex - 1] === " " || - textContent[triggerIndex - 1] === "\n" - ) { - const query = textContent.slice(triggerIndex + 1, cursorPos); - if (!query.startsWith(" ")) { - shouldTriggerAction = true; - actionQuery = query; - } - } - } - } - } - - // If no @ found before cursor, check if text contains @ at all - // If text is empty or doesn't contain @, close the mention - if (!shouldTriggerMention) { - if (text.length === 0 || !text.includes("@")) { - onMentionClose?.(); - } else { - // Text contains @ but not before cursor, close mention - onMentionClose?.(); - } - } else { - onMentionTrigger?.(mentionQuery); - } - - if (!shouldTriggerAction) { - onActionClose?.(); - } else { - onActionTrigger?.(actionQuery); - } - - // Notify parent of change - onChange?.(text, Array.from(mentionedDocs.values())); - rememberSelection(); - }, [ - getText, - mentionedDocs, - onChange, - onMentionTrigger, - onMentionClose, - onActionTrigger, - onActionClose, - rememberSelection, - ]); - - // Handle keydown const handleKeyDown = useCallback( (e: React.KeyboardEvent) => { - // Let parent handle navigation keys when mention popover is open - if (onKeyDown) { - onKeyDown(e); - if (e.defaultPrevented) return; - } + onKeyDown?.(e); + if (e.defaultPrevented) return; - // Handle Enter for submit (without shift) if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); onSubmit?.(); return; } - // Handle backspace on chips - if (e.key === "Backspace") { - const selection = window.getSelection(); - if (selection && selection.rangeCount > 0) { - const range = selection.getRangeAt(0); - if (range.collapsed) { - // Check if cursor is right after a chip - const node = range.startContainer; - const offset = range.startOffset; - - if (node.nodeType === Node.TEXT_NODE && offset === 0) { - // Check previous sibling using type guard - const prevSibling = node.previousSibling; - if (isChipElement(prevSibling)) { - e.preventDefault(); - const chipId = getChipId(prevSibling); - const chipDocType = getChipDocType(prevSibling); - if (chipId !== null) { - prevSibling.remove(); - const chipKey = getMentionDocKey({ - id: chipId, - document_type: chipDocType, - }); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); - // Notify parent that a document was removed - onDocumentRemove?.(chipId, chipDocType); - } 
- return; - } - // Check if we're about to delete @ at the start - const textContent = node.textContent || ""; - if (textContent.length > 0 && textContent[0] === "@") { - // Will delete @, close mention popover - setTimeout(() => { - onMentionClose?.(); - }, 0); - } - } else if (node.nodeType === Node.TEXT_NODE && offset > 0) { - // Check if we're about to delete @ - const textContent = node.textContent || ""; - if (textContent[offset - 1] === "@") { - // Will delete @, close mention popover - setTimeout(() => { - onMentionClose?.(); - }, 0); - } - } else if (node.nodeType === Node.ELEMENT_NODE && offset > 0) { - // Check if previous child is a chip using type guard - const prevChild = (node as Element).childNodes[offset - 1]; - if (isChipElement(prevChild)) { - e.preventDefault(); - const chipId = getChipId(prevChild); - const chipDocType = getChipDocType(prevChild); - if (chipId !== null) { - prevChild.remove(); - const chipKey = getMentionDocKey({ - id: chipId, - document_type: chipDocType, - }); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); - // Notify parent that a document was removed - onDocumentRemove?.(chipId, chipDocType); - } - } - } - } - } + if (e.key !== "Backspace") return; + const selection = editor.selection; + if (!selection || !selection.anchor || !selection.focus) return; + if ( + selection.anchor.path.length < 2 || + selection.focus.path.length < 2 || + selection.anchor.path[0] !== selection.focus.path[0] + ) { + return; } + if (selection.anchor.offset !== 0 || selection.focus.offset !== 0) return; + + const value = getCurrentValue(); + const block = value[selection.anchor.path[0]]; + if (!block) return; + const childIndex = selection.anchor.path[1]; + if (childIndex <= 0) return; + const prev = block.children[childIndex - 1]; + if (!isMentionNode(prev)) return; + + e.preventDefault(); + removeDocumentChip(prev.id, prev.document_type); + onDocumentRemove?.(prev.id, prev.document_type); }, - [onKeyDown, onSubmit, onDocumentRemove, onMentionClose, syncEditorState] + [editor.selection, getCurrentValue, onDocumentRemove, onKeyDown, onSubmit, removeDocumentChip] ); - // Handle paste - strip formatting - const handlePaste = useCallback((e: React.ClipboardEvent) => { - e.preventDefault(); - const text = e.clipboardData.getData("text/plain"); - document.execCommand("insertText", false, text); - }, []); - - // Handle composition (for IME input) - const handleCompositionStart = useCallback(() => { - isComposingRef.current = true; - }, []); - - const handleCompositionEnd = useCallback(() => { - isComposingRef.current = false; - handleInput(); - }, [handleInput]); + const editableProps = useMemo( + () => ({ + placeholder, + onPaste: (e: React.ClipboardEvent) => { + e.preventDefault(); + const text = e.clipboardData.getData("text/plain"); + const tf = editor.tf as { insertText: (value: string) => void }; + tf.insertText(text); + }, + onKeyDown: handleKeyDown, + }), + [editor, handleKeyDown, placeholder] + ); return (
- {/* biome-ignore lint/a11y/noStaticElementInteractions: contenteditable mention editor requires a div for inline chips */} -
-			{/* Placeholder with fade animation on change */}
-			{isEmpty && (
-				<span className={…}>{placeholder}</span>
-			)}
+			<Plate
+				editor={editor}
+				onChange={({ value }) => {
+					emitState(value as ComposerValue);
+				}}
+			>
+				<PlateContent {...editableProps} />
+			</Plate>
); } diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index 7655e10cc..9fddec360 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -12,14 +12,15 @@ import { ExternalLinkIcon } from "lucide-react"; import dynamic from "next/dynamic"; import { useParams } from "next/navigation"; import { useTheme } from "next-themes"; -import { memo, type ReactNode } from "react"; +import { createContext, memo, type ReactNode, useCallback, useContext, useRef } from "react"; import rehypeKatex from "rehype-katex"; import remarkGfm from "remark-gfm"; import remarkMath from "remark-math"; import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image"; import "katex/dist/katex.min.css"; -import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; +import { toast } from "sonner"; +import { processChildrenWithCitations } from "@/components/citations/citation-renderer"; import { Skeleton } from "@/components/ui/skeleton"; import { Table, @@ -30,6 +31,8 @@ import { TableRow, } from "@/components/ui/table"; import { useElectronAPI } from "@/hooks/use-platform"; +import { documentsApiService } from "@/lib/apis/documents-api.service"; +import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; import { cn } from "@/lib/utils"; function MarkdownCodeBlockSkeleton() { @@ -59,31 +62,30 @@ const LazyMarkdownCodeBlock = dynamic( } ); -// Storage for URL citations replaced during preprocess to avoid GFM autolink interference. -// Populated in preprocessMarkdown, consumed in parseTextWithCitations. -let _pendingUrlCitations = new Map(); -let _urlCiteIdx = 0; +// Per-render URL placeholder map propagated to component overrides via +// React Context. Replaces the previous module-level `_pendingUrlCitations` +// state, which was unsafe under concurrent renders / SSR. +type CitationUrlMapRef = { current: CitationUrlMap }; +const EMPTY_URL_MAP: CitationUrlMap = new Map(); +const CitationUrlMapContext = createContext({ current: EMPTY_URL_MAP }); + +function useCitationUrlMap(): CitationUrlMap { + return useContext(CitationUrlMapContext).current; +} /** * Preprocess raw markdown before it reaches the remark/rehype pipeline. * - Replaces URL-based citations with safe placeholders (prevents GFM autolinks) * - Normalises LaTeX delimiters to dollar-sign syntax for remark-math */ -function preprocessMarkdown(content: string): string { +function preprocessMarkdown(content: string, urlMapRef: CitationUrlMapRef): string { // Replace URL-based citations with safe placeholders BEFORE markdown parsing. // GFM autolinks would otherwise convert the https://... inside [citation:URL] // into an element, splitting the text and preventing our citation regex // from matching the full pattern. 
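For orientation, the helper this rewrite delegates to has roughly the contract sketched below. This is a reconstruction from the removed inline logic and its call site, not the actual `lib/citations/citation-parser.ts` source; the regex is copied verbatim from the deleted lines.

```ts
// Sketch: rewrite [citation:https://…] into [citation:urlciteN] placeholders
// before remark/rehype run, returning the placeholder -> URL map alongside
// the rewritten markdown.
export type CitationUrlMap = Map<string, string>;

export function preprocessCitationMarkdown(content: string): {
	content: string;
	urlMap: CitationUrlMap;
} {
	const urlMap: CitationUrlMap = new Map();
	let idx = 0;
	const rewritten = content.replace(
		/[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g,
		(_match, url: string) => {
			const key = `urlcite${idx++}`;
			urlMap.set(key, url.trim());
			return `[citation:${key}]`;
		}
	);
	return { content: rewritten, urlMap };
}
```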
- _pendingUrlCitations = new Map(); - _urlCiteIdx = 0; - content = content.replace( - /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g, - (_, url) => { - const key = `urlcite${_urlCiteIdx++}`; - _pendingUrlCitations.set(key, url.trim()); - return `[citation:${key}]`; - } - ); + const { content: rewritten, urlMap } = preprocessCitationMarkdown(content); + urlMapRef.current = urlMap; + content = rewritten; // All math forms are normalised to $$...$$ so we can disable single-dollar // inline math in remark-math (otherwise currency like "$3,120.00 and $0.00" @@ -116,113 +118,25 @@ function preprocessMarkdown(content: string): string { return content; } -// Matches [citation:...] with numeric IDs (incl. negative, doc- prefix, comma-separated), -// URL-based IDs from live web search, or urlciteN placeholders from preprocess. -// Also matches Chinese brackets 【】 and handles zero-width spaces that LLM sometimes inserts. -const CITATION_REGEX = - /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*)\s*\u200B?[\]】]/g; - -/** - * Parses text and replaces [citation:XXX] patterns with citation components. - * Supports: - * - Numeric chunk IDs: [citation:123] - * - Doc-prefixed IDs: [citation:doc-123] - * - Comma-separated IDs: [citation:4149, 4150, 4151] - * - URL-based citations from live search: [citation:https://example.com/page] - */ -function parseTextWithCitations(text: string): ReactNode[] { - const parts: ReactNode[] = []; - let lastIndex = 0; - let match: RegExpExecArray | null; - let instanceIndex = 0; - - CITATION_REGEX.lastIndex = 0; - - match = CITATION_REGEX.exec(text); - while (match !== null) { - if (match.index > lastIndex) { - parts.push(text.substring(lastIndex, match.index)); - } - - const captured = match[1]; - - if (captured.startsWith("http://") || captured.startsWith("https://")) { - parts.push(); - instanceIndex++; - } else if (captured.startsWith("urlcite")) { - const url = _pendingUrlCitations.get(captured); - if (url) { - parts.push(); - } - instanceIndex++; - } else { - const rawIds = captured.split(",").map((s) => s.trim()); - for (const rawId of rawIds) { - const isDocsChunk = rawId.startsWith("doc-"); - const chunkId = Number.parseInt(isDocsChunk ? rawId.slice(4) : rawId, 10); - parts.push( - - ); - instanceIndex++; - } - } - - lastIndex = match.index + match[0].length; - match = CITATION_REGEX.exec(text); - } - - if (lastIndex < text.length) { - parts.push(text.substring(lastIndex)); - } - - return parts.length > 0 ? parts : [text]; -} - const MarkdownTextImpl = () => { + const urlMapRef = useRef(EMPTY_URL_MAP); + const preprocess = useCallback((content: string) => preprocessMarkdown(content, urlMapRef), []); return ( - + + + ); }; export const MarkdownText = memo(MarkdownTextImpl); -/** - * Helper to process children and replace citation patterns with components - */ -function processChildrenWithCitations(children: ReactNode): ReactNode { - if (typeof children === "string") { - const parsed = parseTextWithCitations(children); - return parsed.length === 1 && typeof parsed[0] === "string" ? children : parsed; - } - - if (Array.isArray(children)) { - return children.map((child) => { - if (typeof child === "string") { - const parsed = parseTextWithCitations(child); - return parsed.length === 1 && typeof parsed[0] === "string" ? 
( - child - ) : ( - {parsed} - ); - } - return child; - }); - } - - return children; -} - function extractDomain(url: string): string { try { const parsed = new URL(url); @@ -282,6 +196,85 @@ function isVirtualFilePathToken(value: string): boolean { return segments.length >= 2; } +function isStandaloneDocumentsPathText(node: ReactNode): string | null { + if (typeof node !== "string") return null; + const value = node.trim(); + if (!value.startsWith("/documents/")) return null; + if (value.includes(" ")) return null; + const normalized = value.replace(/\/+$/, ""); + const leaf = normalized.split("/").filter(Boolean).at(-1) ?? ""; + if (!leaf || !leaf.includes(".")) return null; + return value; +} + +function FilePathLink({ path, className }: { path: string; className?: string }) { + const openEditorPanel = useSetAtom(openEditorPanelAtom); + const params = useParams(); + const electronAPI = useElectronAPI(); + const searchSpaceIdParam = params?.search_space_id; + const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam) + ? Number(searchSpaceIdParam[0]) + : Number(searchSpaceIdParam); + const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId) + ? parsedSearchSpaceId + : undefined; + + return ( + + ); +} + function MarkdownImage({ src, alt }: { src?: string; alt?: string }) { if (!src) return null; @@ -322,92 +315,127 @@ function MarkdownImage({ src, alt }: { src?: string; alt?: string }) { } const defaultComponents = memoizeMarkdownComponents({ - h1: ({ className, children, ...props }) => ( -

-		<h1 className={cn(…, className)} {...props}>
-			{processChildrenWithCitations(children)}
-		</h1>
-	),
-	h2: ({ className, children, ...props }) => (
-		<h2 className={cn(…, className)} {...props}>
-			{processChildrenWithCitations(children)}
-		</h2>
-	),
-	h3: ({ className, children, ...props }) => (
-		<h3 className={cn(…, className)} {...props}>
-			{processChildrenWithCitations(children)}
-		</h3>
-	),
-	h4: ({ className, children, ...props }) => (
-		<h4 className={cn(…, className)} {...props}>
-			{processChildrenWithCitations(children)}
-		</h4>
-	),
-	h5: ({ className, children, ...props }) => (
-		<h5 className={cn(…, className)} {...props}>
-			{processChildrenWithCitations(children)}
-		</h5>
-	),
-	h6: ({ className, children, ...props }) => (
-		<h6 className={cn(…, className)} {...props}>
-			{processChildrenWithCitations(children)}
-		</h6>
-	),
-	p: ({ className, children, ...props }) => (
-		<p className={cn(…, className)} {...props}>
-			{processChildrenWithCitations(children)}
-		</p>
-	),
-	a: ({ className, children, ...props }) => (
-		<a className={cn(…, className)} {...props}>
-			{processChildrenWithCitations(children)}
-		</a>
-	),
-	blockquote: ({ className, children, ...props }) => (
-		<blockquote className={cn(…, className)} {...props}>
-			{processChildrenWithCitations(children)}
-		</blockquote>
-	),
+	h1: function H1({ className, children, ...props }) {
+		const urlMap = useCitationUrlMap();
+		return (
+			<h1 className={cn(…, className)} {...props}>
+				{processChildrenWithCitations(children, urlMap)}
+			</h1>
+		);
+	},
+	h2: function H2({ className, children, ...props }) {
+		const urlMap = useCitationUrlMap();
+		return (
+			<h2 className={cn(…, className)} {...props}>
+				{processChildrenWithCitations(children, urlMap)}
+			</h2>
+		);
+	},
+	h3: function H3({ className, children, ...props }) {
+		const urlMap = useCitationUrlMap();
+		return (
+			<h3 className={cn(…, className)} {...props}>
+				{processChildrenWithCitations(children, urlMap)}
+			</h3>
+		);
+	},
+	h4: function H4({ className, children, ...props }) {
+		const urlMap = useCitationUrlMap();
+		return (
+			<h4 className={cn(…, className)} {...props}>
+				{processChildrenWithCitations(children, urlMap)}
+			</h4>
+		);
+	},
+	h5: function H5({ className, children, ...props }) {
+		const urlMap = useCitationUrlMap();
+		return (
+			<h5 className={cn(…, className)} {...props}>
+				{processChildrenWithCitations(children, urlMap)}
+			</h5>
+		);
+	},
+	h6: function H6({ className, children, ...props }) {
+		const urlMap = useCitationUrlMap();
+		return (
+			<h6 className={cn(…, className)} {...props}>
+				{processChildrenWithCitations(children, urlMap)}
+			</h6>
+		);
+	},
+	p: function P({ className, children, ...props }) {
+		const urlMap = useCitationUrlMap();
+		const standalonePath = isStandaloneDocumentsPathText(children);
+		return (
+			<p className={cn(…, className)} {...props}>
+				{standalonePath ? (
+					<FilePathLink path={standalonePath} className={…} />
+				) : (
+					processChildrenWithCitations(children, urlMap)
+				)}
+			</p>
+		);
+	},
+	a: function A({ className, children, ...props }) {
+		const urlMap = useCitationUrlMap();
+		return (
+			<a className={cn(…, className)} {...props}>
+				{processChildrenWithCitations(children, urlMap)}
+			</a>
+		);
+	},
+	blockquote: function Blockquote({ className, children, ...props }) {
+		const urlMap = useCitationUrlMap();
+		return (
+			<blockquote className={cn(…, className)} {...props}>
+				{processChildrenWithCitations(children, urlMap)}
+			</blockquote>
+		);
+	},
 	ul: ({ className, ...props }) => (
 		<ul className={cn("… [&>li]:mt-2", className)} {...props} />
 	),
 	ol: ({ className, ...props }) => (
 		<ol className={cn("… [&>li]:mt-2", className)} {...props} />
 	),
-	li: ({ className, children, ...props }) => (
-		<li className={cn(…, className)} {...props}>
-			{processChildrenWithCitations(children)}
-		</li>
-	),
+	li: function Li({ className, children, ...props }) {
+		const urlMap = useCitationUrlMap();
+		return (
+			<li className={cn(…, className)} {...props}>
+				{processChildrenWithCitations(children, urlMap)}
+			</li>
+		);
+	},
 	hr: ({ className, ...props }) => (
 		<hr className={cn(…, className)} {...props} />
      ), @@ -422,28 +450,34 @@ const defaultComponents = memoizeMarkdownComponents({ tbody: ({ className, ...props }) => ( ), - th: ({ className, children, ...props }) => ( - - {processChildrenWithCitations(children)} - - ), - td: ({ className, children, ...props }) => ( - - {processChildrenWithCitations(children)} - - ), + th: function Th({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + + {processChildrenWithCitations(children, urlMap)} + + ); + }, + td: function Td({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + + {processChildrenWithCitations(children, urlMap)} + + ); + }, tr: ({ className, ...props }) => , sup: ({ className, ...props }) => ( a]:text-xs [&>a]:no-underline", className)} {...props} /> @@ -452,8 +486,6 @@ const defaultComponents = memoizeMarkdownComponents({ code: function Code({ className, children, ...props }) { const isCodeBlock = useIsMarkdownCodeBlock(); const { resolvedTheme } = useTheme(); - const openEditorPanel = useSetAtom(openEditorPanelAtom); - const params = useParams(); const electronAPI = useElectronAPI(); const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text"; const codeString = String(children).replace(/\n$/, ""); @@ -470,53 +502,17 @@ const defaultComponents = memoizeMarkdownComponents({ const isLikelyFolder = inlineValue.endsWith("/") || !leafSegment || !leafSegment.includes("."); const isLocalPath = - !!electronAPI && - isVirtualFilePathToken(inlineValue) && - !inlineValue.startsWith("//") && - !isLikelyFolder; - const displayLocalPath = inlineValue.replace(/^\/+/, ""); - const searchSpaceIdParam = params?.search_space_id; - const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam) - ? Number(searchSpaceIdParam[0]) - : Number(searchSpaceIdParam); + (isVirtualFilePathToken(inlineValue) && + !inlineValue.startsWith("//") && + !isLikelyFolder && + !!electronAPI) || + (isVirtualFilePathToken(inlineValue) && + !inlineValue.startsWith("//") && + !isLikelyFolder && + !electronAPI && + inlineValue.startsWith("/documents/")); if (isLocalPath) { - return ( - - ); + return ; } return ( ); }, - strong: ({ className, children, ...props }) => ( - - {processChildrenWithCitations(children)} - - ), - em: ({ className, children, ...props }) => ( - - {processChildrenWithCitations(children)} - - ), + strong: function Strong({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + + {processChildrenWithCitations(children, urlMap)} + + ); + }, + em: function Em({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + + {processChildrenWithCitations(children, urlMap)} + + ); + }, img: ({ src, alt }) => ( ), diff --git a/surfsense_web/components/assistant-ui/nested-scroll.tsx b/surfsense_web/components/assistant-ui/nested-scroll.tsx new file mode 100644 index 000000000..37c4790df --- /dev/null +++ b/surfsense_web/components/assistant-ui/nested-scroll.tsx @@ -0,0 +1,24 @@ +"use client"; + +import { type ComponentPropsWithoutRef, forwardRef, type WheelEvent } from "react"; + +export type NestedScrollProps = ComponentPropsWithoutRef<"div">; + +export const NestedScroll = forwardRef( + ({ onWheel, ...props }, ref) => { + const handleWheel = (event: WheelEvent) => { + const el = event.currentTarget; + const canScrollUp = el.scrollTop > 0; + const canScrollDown = el.scrollTop < el.scrollHeight - el.clientHeight - 1; + const goingUp = event.deltaY < 0; + const goingDown = event.deltaY > 0; + if ((goingUp && canScrollUp) || 
(goingDown && canScrollDown)) {
+				event.stopPropagation();
+			}
+			onWheel?.(event);
+		};
+		return <div ref={ref} onWheel={handleWheel} {...props} />
      ; + } +); + +NestedScroll.displayName = "NestedScroll"; diff --git a/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx b/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx deleted file mode 100644 index 394ba5d79..000000000 --- a/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx +++ /dev/null @@ -1,18 +0,0 @@ -import { ThreadPrimitive } from "@assistant-ui/react"; -import { ArrowDownIcon } from "lucide-react"; -import type { FC } from "react"; -import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; - -export const ThreadScrollToBottom: FC = () => { - return ( - - - - - - ); -}; diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index e58783c87..b4a3b58c6 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -5,12 +5,10 @@ import { ThreadPrimitive, useAui, useAuiState, - useThreadViewportStore, } from "@assistant-ui/react"; import { useAtom, useAtomValue, useSetAtom } from "jotai"; import { AlertCircle, - ArrowDownIcon, ArrowUpIcon, Camera, ChevronDown, @@ -37,10 +35,13 @@ import { toggleToolAtom, } from "@/atoms/agent-tools/agent-tools.atoms"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; -import { - mentionedDocumentsAtom, -} from "@/atoms/chat/mentioned-documents.atom"; +import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; +import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom"; import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; +import { + clearPremiumAlertForThreadAtom, + premiumAlertByThreadAtom, +} from "@/atoms/chat/premium-alert.atom"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { membersAtom } from "@/atoms/members/members-query.atoms"; @@ -52,6 +53,7 @@ import { import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { AssistantMessage } from "@/components/assistant-ui/assistant-message"; import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status"; +import { ChatViewport } from "@/components/assistant-ui/chat-viewport"; import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup"; import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup"; import { @@ -90,8 +92,8 @@ import { useBatchCommentsPreload } from "@/hooks/use-comments"; import { useCommentsSync } from "@/hooks/use-comments-sync"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; -import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { captureDisplayToPngDataUrl } from "@/lib/chat/display-media-capture"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events"; import { cn } from "@/lib/utils"; @@ -109,10 +111,13 @@ const ThreadContent: FC = () => { ["--thread-max-width" as string]: "44rem", }} > - !thread.isEmpty}> + + + + } > thread.isEmpty}> @@ -125,36 +130,39 @@ const ThreadContent: FC = () => { AssistantMessage, }} /> - - !thread.isEmpty}> -
      - - - - - !thread.isEmpty}> - - - - + ); }; -const ThreadScrollToBottom: FC = () => { +const PremiumQuotaPinnedAlert: FC = () => { + const currentThreadState = useAtomValue(currentThreadAtom); + const alertsByThread = useAtomValue(premiumAlertByThreadAtom); + const clearPremiumAlertForThread = useSetAtom(clearPremiumAlertForThreadAtom); + + const currentThreadId = currentThreadState?.id; + if (!currentThreadId) return null; + + const alert = alertsByThread[currentThreadId]; + if (!alert) return null; + return ( - - - - - +
+		<div className={…}>
+			<AlertCircle className={…} />
+			<div className={…}>
+				<p className={…}>{alert.message}</p>
+			</div>
+			<button
+				type="button"
+				className={…}
+				onClick={() => clearPremiumAlertForThread(currentThreadId)}
+			>
+				…
+			</button>
+		</div>
      ); }; @@ -374,23 +382,9 @@ const Composer: FC = () => { >(new Map()); const documentPickerRef = useRef(null); const promptPickerRef = useRef(null); - const viewportRef = useRef(null); const { search_space_id, chat_id } = useParams(); const aui = useAui(); - const threadViewportStore = useThreadViewportStore(); const hasAutoFocusedRef = useRef(false); - const submitCleanupRef = useRef<(() => void) | null>(null); - - useEffect(() => { - return () => { - submitCleanupRef.current?.(); - }; - }, []); - - // Store viewport element reference on mount - useEffect(() => { - viewportRef.current = document.querySelector(".aui-thread-viewport"); - }, []); const electronAPI = useElectronAPI(); const [clipboardInitialText, setClipboardInitialText] = useState(); @@ -589,7 +583,6 @@ const Composer: FC = () => { [showDocumentPopover, showPromptPicker] ); - // Submit message (blocked during streaming, document picker open, or AI responding to another user) const handleSubmit = useCallback(() => { if (isThreadRunning || isBlockedByOtherUser) return; if (showDocumentPopover || showPromptPicker) return; @@ -601,50 +594,9 @@ const Composer: FC = () => { setClipboardInitialText(undefined); } - const viewportEl = viewportRef.current; - const heightBefore = viewportEl?.scrollHeight ?? 0; - aui.composer().send(); editorRef.current?.clear(); setMentionedDocuments([]); - - // With turnAnchor="top", ViewportSlack adds min-height to the last - // assistant message so that scrolling-to-bottom actually positions the - // user message at the TOP of the viewport. That slack height is - // calculated asynchronously (ResizeObserver → style → layout). - // Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes. - const scrollToBottom = () => - threadViewportStore.getState().scrollToBottom({ behavior: "instant" }); - - let lastHeight = heightBefore; - let frames = 0; - let cancelled = false; - const POLL_FRAMES = 30; - - const pollAndScroll = () => { - if (cancelled) return; - const el = viewportRef.current; - if (el) { - const h = el.scrollHeight; - if (h !== lastHeight) { - lastHeight = h; - scrollToBottom(); - } - } - if (++frames < POLL_FRAMES) { - requestAnimationFrame(pollAndScroll); - } - }; - requestAnimationFrame(pollAndScroll); - - const t1 = setTimeout(scrollToBottom, 100); - const t2 = setTimeout(scrollToBottom, 300); - - submitCleanupRef.current = () => { - cancelled = true; - clearTimeout(t1); - clearTimeout(t2); - }; }, [ showDocumentPopover, showPromptPicker, @@ -653,7 +605,6 @@ const Composer: FC = () => { clipboardInitialText, aui, setMentionedDocuments, - threadViewportStore, ]); const handleDocumentRemove = useCallback( diff --git a/surfsense_web/components/assistant-ui/token-usage-context.tsx b/surfsense_web/components/assistant-ui/token-usage-context.tsx index b3f71ab21..dd80bcac3 100644 --- a/surfsense_web/components/assistant-ui/token-usage-context.tsx +++ b/surfsense_web/components/assistant-ui/token-usage-context.tsx @@ -13,13 +13,30 @@ export interface TokenUsageData { prompt_tokens: number; completion_tokens: number; total_tokens: number; + /** + * Total provider USD cost for this assistant turn, in micro-USD + * (1_000_000 = $1.00). Populated from LiteLLM's response_cost on + * the backend. Optional because pre-cost-credits messages persisted + * before the migration won't have it. 
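+	 * Example (illustrative): cost_micros = 12_500 means the turn cost
+	 * $0.0125; clients should render cost_micros / 1_000_000 as dollars.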
+ */ + cost_micros?: number; usage?: Record< string, - { prompt_tokens: number; completion_tokens: number; total_tokens: number } + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } >; model_breakdown?: Record< string, - { prompt_tokens: number; completion_tokens: number; total_tokens: number } + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } >; } diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index 66e2ebd4a..06082c9c7 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -1,13 +1,11 @@ -import { - type ToolCallMessagePartComponent, - useAuiState, -} from "@assistant-ui/react"; +import { type ToolCallMessagePartComponent, useAuiState } from "@assistant-ui/react"; import { useQueryClient } from "@tanstack/react-query"; import { useAtomValue } from "jotai"; import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react"; import { useEffect, useMemo, useState } from "react"; import { toast } from "sonner"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; +import { NestedScroll } from "@/components/assistant-ui/nested-scroll"; import { DoomLoopApprovalToolUI, isDoomLoopInterrupt, @@ -31,10 +29,7 @@ import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/component import { Separator } from "@/components/ui/separator"; import { Spinner } from "@/components/ui/spinner"; import { getToolDisplayName } from "@/contracts/enums/toolIcons"; -import { - markActionRevertedInCache, - useAgentActionsQuery, -} from "@/hooks/use-agent-actions-query"; +import { markActionRevertedInCache, useAgentActionsQuery } from "@/hooks/use-agent-actions-query"; import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { AppError } from "@/lib/error"; import { isInterruptResult } from "@/lib/hitl"; @@ -123,8 +118,7 @@ function ToolCardRevertButton({ // Tier 1 + 2: O(1) Map-backed direct id match. Covers // ~all parity_v2 streams and any legacy stream that backfilled // ``langchainToolCallId`` via ``tool-output-available``. - const direct = - findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId); + const direct = findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId); if (direct) return direct; // Tier 3: position-within-turn fallback. Only kicks in when the // card has a synthetic ``call_`` id AND no @@ -159,12 +153,7 @@ function ToolCardRevertButton({ setIsReverting(true); try { const response = await agentActionsApiService.revert(threadId, action.id); - markActionRevertedInCache( - queryClient, - threadId, - action.id, - response.new_action_id ?? null - ); + markActionRevertedInCache(queryClient, threadId, action.id, response.new_action_id ?? null); toast.success(response.message || "Action reverted."); } catch (err) { // 503 means revert is gated off on this deployment — hide the @@ -475,7 +464,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { {(argsText || isRunning) && (

      Inputs

      -
+							<NestedScroll className={…}>
 								{argsText ? (
       											{argsText}
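For context, `NestedScroll` stops wheel-event propagation only while its own pane can still scroll in the wheel's direction, so tool output scrolls internally before the outer chat viewport takes over. A minimal usage sketch; `ToolOutputPane` and its class names are illustrative, not part of this PR:

```tsx
import { NestedScroll } from "@/components/assistant-ui/nested-scroll";

// Inner pane that consumes wheel events until it hits an edge, then
// lets the surrounding thread viewport continue scrolling.
export function ToolOutputPane({ text }: { text: string }) {
	return (
		<NestedScroll className="max-h-64 overflow-y-auto rounded-md border">
			<pre className="whitespace-pre-wrap p-2 text-xs">{text}</pre>
		</NestedScroll>
	);
}
```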
      @@ -489,7 +478,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
       											Waiting for input…
       										

      )} -
+							</NestedScroll>
      )} {!isCancelled && result !== undefined && ( @@ -497,11 +486,11 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {

      Result

      -
+							<NestedScroll className={…}>
       											{typeof result === "string" ? result : serializedResult}
       										
      -
+							</NestedScroll>
      )} diff --git a/surfsense_web/components/assistant-ui/user-message.tsx b/surfsense_web/components/assistant-ui/user-message.tsx index fb7212119..145ac2d7e 100644 --- a/surfsense_web/components/assistant-ui/user-message.tsx +++ b/surfsense_web/components/assistant-ui/user-message.tsx @@ -1,4 +1,10 @@ -import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react"; +import { + ActionBarPrimitive, + AuiIf, + MessagePrimitive, + useAuiState, + useMessagePartText, +} from "@assistant-ui/react"; import { useAtomValue } from "jotai"; import { CheckIcon, CopyIcon, Pencil } from "lucide-react"; import Image from "next/image"; @@ -7,6 +13,8 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; +import { parseMentionSegments } from "@/lib/chat/parse-mention-segments"; interface AuthorMetadata { displayName: string | null; @@ -47,23 +55,40 @@ const UserAvatar: FC = ({ displayName, avatarUrl }) => { ); }; -export const UserMessage: FC = () => { +const UserTextPart: FC = () => { const messageId = useAuiState(({ message }) => message?.id); - const messageText = useAuiState(({ message }) => - (message?.content ?? []) - .map((part) => - typeof part === "object" && - part !== null && - "type" in part && - (part as { type?: string }).type === "text" && - "text" in part - ? String((part as { text?: string }).text ?? "") - : "" - ) - .join("") - ); + const part = useMessagePartText(); + const text = (part as { text?: string }).text ?? ""; const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); - const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined; + const mentionedDocs = (messageId ? messageDocumentsMap[messageId] : undefined) ?? []; + + const segments = parseMentionSegments(text, mentionedDocs); + + return ( +

+			{segments.map((segment) =>
+				segment.type === "text" ? (
+					<span key={segment.start}>{segment.value}</span>
+				) : (
+					<span key={segment.start} className={…}>
+						<span className={…}>
+							{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
+						</span>
+						{segment.doc.title}
+					</span>
+				)
+			)}

      + ); +}; + +const userMessageParts = { Text: UserTextPart }; + +export const UserMessage: FC = () => { const metadata = useAuiState(({ message }) => message?.metadata); const author = metadata?.custom?.author as AuthorMetadata | undefined; const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE"; @@ -78,11 +103,7 @@ export const UserMessage: FC = () => {
-							{mentionedDocs && mentionedDocs.length > 0 ? (
-								<UserMessageWithMentionChips text={messageText} mentionedDocs={mentionedDocs} />
-							) : (
-								<MessagePrimitive.Parts />
-							)}
+							<MessagePrimitive.Parts components={userMessageParts} />
      @@ -99,64 +120,6 @@ export const UserMessage: FC = () => { ); }; -const UserMessageWithMentionChips: FC<{ - text: string; - mentionedDocs: { id: number; title: string; document_type: string }[]; -}> = ({ text, mentionedDocs }) => { - type Segment = - | { type: "text"; value: string; start: number } - | { type: "mention"; doc: { id: number; title: string; document_type: string }; start: number }; - - const tokens = mentionedDocs - .map((doc) => ({ doc, token: `@${doc.title}` })) - .sort((a, b) => b.token.length - a.token.length); - - const segments: Segment[] = []; - let i = 0; - let buffer = ""; - let bufferStart = 0; - while (i < text.length) { - const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i)); - if (tokenMatch) { - if (buffer) { - segments.push({ type: "text", value: buffer, start: bufferStart }); - buffer = ""; - } - segments.push({ type: "mention", doc: tokenMatch.doc, start: i }); - i += tokenMatch.token.length; - bufferStart = i; - continue; - } - if (!buffer) bufferStart = i; - buffer += text[i]; - i += 1; - } - if (buffer) { - segments.push({ type: "text", value: buffer, start: bufferStart }); - } - - return ( - - {segments.map((segment) => - segment.type === "text" ? ( - {segment.value} - ) : ( - - - {getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")} - - {segment.doc.title} - - ) - )} - - ); -}; - const UserActionBar: FC = () => { const isThreadRunning = useAuiState(({ thread }) => thread.isRunning); diff --git a/surfsense_web/components/citation-panel/citation-panel.tsx b/surfsense_web/components/citation-panel/citation-panel.tsx index cec07b9cf..ed8acd656 100644 --- a/surfsense_web/components/citation-panel/citation-panel.tsx +++ b/surfsense_web/components/citation-panel/citation-panel.tsx @@ -169,7 +169,7 @@ export const CitationPanelContent: FC = ({ chunkId, o )}
      - +
      ); diff --git a/surfsense_web/components/citations/citation-renderer.tsx b/surfsense_web/components/citations/citation-renderer.tsx new file mode 100644 index 000000000..f2de4b27d --- /dev/null +++ b/surfsense_web/components/citations/citation-renderer.tsx @@ -0,0 +1,77 @@ +"use client"; + +import type { ReactNode } from "react"; +import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; +import { + type CitationToken, + type CitationUrlMap, + parseTextWithCitations, +} from "@/lib/citations/citation-parser"; + +/** + * Render a single parsed citation token as JSX. + * + * `ordinalKey` should be a stable per-render counter so duplicate identical + * citations within the same parent don't collide on `key`. The previous + * implementation in `markdown-text.tsx` used the source string itself as + * the key, which produced React warnings when two segments rendered the + * same `[citation:N]` text. + */ +export function renderCitationToken(token: CitationToken, ordinalKey: number): ReactNode { + if (token.kind === "url") { + return ; + } + return ( + + ); +} + +/** + * Walk a `ReactNode` (string, array, or arbitrary node) and replace any + * `[citation:...]` tokens inside string children with citation badges. + * + * Designed for use inside `Streamdown`/`react-markdown` `components` + * overrides where the renderer hands you `children`. Non-string children + * are returned untouched so block/phrasing structure is preserved. + */ +export function processChildrenWithCitations( + children: ReactNode, + urlMap: CitationUrlMap +): ReactNode { + if (typeof children === "string") { + const segments = parseTextWithCitations(children, urlMap); + if (segments.length === 1 && typeof segments[0] === "string") { + return children; + } + let ordinal = 0; + return segments.map((segment) => + typeof segment === "string" ? segment : renderCitationToken(segment, ordinal++) + ); + } + + if (Array.isArray(children)) { + let ordinal = 0; + return children.map((child, childIndex) => { + if (typeof child === "string") { + const segments = parseTextWithCitations(child, urlMap); + if (segments.length === 1 && typeof segments[0] === "string") { + return child; + } + return ( + + {segments.map((segment) => + typeof segment === "string" ? segment : renderCitationToken(segment, ordinal++) + )} + + ); + } + return child; + }); + } + + return children; +} diff --git a/surfsense_web/components/document-viewer.tsx b/surfsense_web/components/document-viewer.tsx index 0f283e567..710a04ba3 100644 --- a/surfsense_web/components/document-viewer.tsx +++ b/surfsense_web/components/document-viewer.tsx @@ -32,7 +32,7 @@ export function DocumentViewer({ title, content, trigger }: DocumentViewerProps) {title}
      - +
      diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index df138e97e..eab07a91b 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -652,7 +652,7 @@ export function EditorPanelContent({ // Plate is heavy on multi-MB docs.
      {largeDocAlert} - +
      ) : renderInPlateEditor ? ( // Editable doc (FILE/NOTE) — Plate editing UX. @@ -670,12 +670,17 @@ export function EditorPanelContent({ reserveToolbarSpace defaultEditing={isEditing} className="**:[[role=toolbar]]:bg-sidebar!" + // Render `[citation:N]` badges in view mode only. + // Edit mode keeps raw text so the user can edit/delete + // tokens directly. `local_file` never reaches this branch + // (handled by the source_code editor above). + enableCitations={!isEditing && !isLocalFileMode} />
      ) : (
      - +
      )}
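To make the `enableCitations` contract concrete, a hypothetical read-only consumer would gate it exactly the way `EditorPanelContent` does above; `DocumentView` below is illustrative, not part of this PR:

```tsx
import { PlateEditor } from "@/components/editor/plate-editor";

// Citations render as badges only while the document is not being
// edited; edit mode keeps the raw [citation:N] tokens as plain text.
export function DocumentView({ markdown, isEditing }: { markdown: string; isEditing: boolean }) {
	return <PlateEditor markdown={markdown} defaultEditing={isEditing} enableCitations={!isEditing} />;
}
```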
      diff --git a/surfsense_web/components/editor/plate-editor.tsx b/surfsense_web/components/editor/plate-editor.tsx index 7f12d3cae..c42cb991e 100644 --- a/surfsense_web/components/editor/plate-editor.tsx +++ b/surfsense_web/components/editor/plate-editor.tsx @@ -8,9 +8,11 @@ import { useEffect, useMemo, useRef } from "react"; import remarkGfm from "remark-gfm"; import remarkMath from "remark-math"; import { EditorSaveContext } from "@/components/editor/editor-save-context"; +import { CitationKit, injectCitationNodes } from "@/components/editor/plugins/citation-kit"; import { type EditorPreset, presetMap } from "@/components/editor/presets"; import { escapeMdxExpressions } from "@/components/editor/utils/escape-mdx"; import { Editor, EditorContainer } from "@/components/ui/editor"; +import { preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; /** Live editor instance returned by `usePlateEditor`. */ export type PlateEditorInstance = ReturnType; @@ -65,6 +67,14 @@ export interface PlateEditorProps { * without modifying the core editor component. */ extraPlugins?: AnyPluginConfig[]; + /** + * Render `[citation:N]` and `[citation:URL]` tokens in the deserialized + * markdown as interactive citation badges/popovers (mirrors chat). Only + * meant for read-only views — when true, `onMarkdownChange` is suppressed + * because the in-memory tree contains custom inline-void elements that + * have no markdown serialize rule. + */ + enableCitations?: boolean; } function PlateEditorContent({ @@ -103,6 +113,7 @@ export function PlateEditor({ defaultEditing = false, preset = "full", extraPlugins = [], + enableCitations = false, }: PlateEditorProps) { const lastMarkdownRef = useRef(markdown); const lastHtmlRef = useRef(html); @@ -145,6 +156,8 @@ export function PlateEditor({ ...(onSave ? [SaveShortcutPlugin] : []), // Consumer-provided extra plugins ...extraPlugins, + // Citation void inline element (read-only document viewer). + ...(enableCitations ? CitationKit : []), MarkdownPlugin.configure({ options: { remarkPlugins: [remarkGfm, remarkMath, remarkMdx], @@ -154,8 +167,18 @@ export function PlateEditor({ value: html ? (editor) => editor.api.html.deserialize({ element: html }) as Value : markdown - ? (editor) => - editor.getApi(MarkdownPlugin).markdown.deserialize(escapeMdxExpressions(markdown)) + ? 
(editor) => { + if (!enableCitations) { + return editor + .getApi(MarkdownPlugin) + .markdown.deserialize(escapeMdxExpressions(markdown)); + } + const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown); + const value = editor + .getApi(MarkdownPlugin) + .markdown.deserialize(escapeMdxExpressions(rewritten)); + return injectCitationNodes(value as Descendant[], urlMap) as Value; + } : undefined, }); @@ -174,13 +197,22 @@ export function PlateEditor({ useEffect(() => { if (!html && markdown !== undefined && markdown !== lastMarkdownRef.current) { lastMarkdownRef.current = markdown; - const newValue = editor - .getApi(MarkdownPlugin) - .markdown.deserialize(escapeMdxExpressions(markdown)); + let newValue: Descendant[]; + if (enableCitations) { + const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown); + const deserialized = editor + .getApi(MarkdownPlugin) + .markdown.deserialize(escapeMdxExpressions(rewritten)) as Descendant[]; + newValue = injectCitationNodes(deserialized, urlMap); + } else { + newValue = editor + .getApi(MarkdownPlugin) + .markdown.deserialize(escapeMdxExpressions(markdown)) as Descendant[]; + } editor.tf.reset(); - editor.tf.setValue(newValue); + editor.tf.setValue(newValue as Value); } - }, [html, markdown, editor]); + }, [html, markdown, editor, enableCitations]); // When not forced read-only, the user can toggle between editing/viewing. const canToggleMode = !readOnly && allowModeToggle; @@ -205,6 +237,16 @@ export function PlateEditor({ // (initialized to true via usePlateEditor, toggled via ModeToolbarButton). {...(readOnly ? { readOnly: true } : {})} onChange={({ value }) => { + // View-only citation mode: skip serialization. The custom + // `citation` inline-void element has no markdown serialize + // rule, so emitting changes here would overwrite + // `lastMarkdownRef.current` (and downstream copy-to-clipboard + // state in EditorPanelContent) with a tree that loses every + // citation token. `enableCitations` is only ever set in + // read-only paths, so user input cannot reach this branch + // in practice — the guard exists for the initial Plate + // normalize emit. + if (enableCitations) return; if (onHtmlChange && html) { const serialized = slateToHtml(value as Descendant[]); onHtmlChange(serialized); diff --git a/surfsense_web/components/editor/plugins/citation-kit.tsx b/surfsense_web/components/editor/plugins/citation-kit.tsx new file mode 100644 index 000000000..1908de209 --- /dev/null +++ b/surfsense_web/components/editor/plugins/citation-kit.tsx @@ -0,0 +1,218 @@ +"use client"; + +import { type Descendant, KEYS } from "platejs"; +import { createPlatePlugin, type PlateElementProps } from "platejs/react"; +import type { FC } from "react"; +import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; +import { + CITATION_REGEX, + type CitationUrlMap, + parseTextWithCitations, +} from "@/lib/citations/citation-parser"; + +/** + * Plate inline-void node modeling a single `[citation:...]` reference. + * + * Modeled after the existing `MentionPlugin` pattern in + * `inline-mention-editor.tsx` — the only confirmed pattern in this repo + * for non-text inline UI. Inline-void elements satisfy Slate's invariant + * that the editor renders both atomic widgets and surrounding text + * cleanly without breaking selection / caret semantics. 
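+ *
+ * Illustrative instance (hypothetical id):
+ *   { type: "citation", kind: "chunk", chunkId: 42,
+ *     rawText: "[citation:42]", children: [{ text: "" }] }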
+ */ +export type CitationElementNode = { + type: "citation"; + kind: "chunk" | "doc" | "url"; + chunkId?: number; + url?: string; + /** Original `[citation:...]` substring for traceability/debugging. */ + rawText: string; + children: [{ text: "" }]; +}; + +const CITATION_TYPE = "citation"; + +const CitationElement: FC> = ({ + attributes, + children, + element, +}) => { + const isUrl = element.kind === "url"; + return ( + + + {isUrl && element.url ? ( + + ) : element.chunkId !== undefined ? ( + + ) : null} + + {children} + + ); +}; + +const CitationPlugin = createPlatePlugin({ + key: CITATION_TYPE, + node: { + isElement: true, + isInline: true, + isVoid: true, + type: CITATION_TYPE, + component: CitationElement, + }, +}); + +/** Plugin kit shape used elsewhere in the editor. */ +export const CitationKit = [CitationPlugin]; + +// --------------------------------------------------------------------------- +// Slate value transform — runs after MarkdownPlugin.deserialize +// --------------------------------------------------------------------------- + +// Structural shapes used by the value transform. We cannot use Plate's +// generic Element / Text type predicates directly because `Descendant` is a +// constrained union and our predicates would over-narrow. Casting through +// these row types keeps the walker readable without fighting the types. +type SlateText = { text: string } & Record; +type SlateElement = { type?: string; children: Descendant[] } & Record; + +function isText(node: Descendant): boolean { + return typeof (node as { text?: unknown }).text === "string"; +} + +function asText(node: Descendant): SlateText { + return node as unknown as SlateText; +} + +function asElement(node: Descendant): SlateElement { + return node as unknown as SlateElement; +} + +/** + * Element types whose subtrees we MUST NOT inject citation void elements + * into. Each rationale documented in the citation plan: + * - `KEYS.codeBlock` / `code_line` — Plate's schema rejects inline elements + * inside code containers; the user expects literal text inside code. + * - `KEYS.link` — `