diff --git a/VERSION b/VERSION index 44517d518..fe04e7f67 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.19 +0.0.20 diff --git a/docker/.env.example b/docker/.env.example index 95de0cf85..fd56bdccc 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE # STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 # STRIPE_RECONCILIATION_BATCH_SIZE=100 -# Premium token purchases ($1 per 1M tokens for premium-tier models) +# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of +# credit; premium turns debit the actual per-call provider cost +# reported by LiteLLM, so cheap and expensive models bill proportionally) # STRIPE_TOKEN_BUYING_ENABLED=FALSE # STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... -# STRIPE_TOKENS_PER_UNIT=1000000 +# STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000 # ------------------------------------------------------------------------------ # TTS & STT (Text-to-Speech / Speech-to-Text) @@ -305,6 +308,24 @@ STT_SERVICE=local/base # Advanced (optional) # ------------------------------------------------------------------------------ +# New-chat agent feature flags +SURFSENSE_ENABLE_CONTEXT_EDITING=true +SURFSENSE_ENABLE_COMPACTION_V2=true +SURFSENSE_ENABLE_RETRY_AFTER=true +SURFSENSE_ENABLE_MODEL_FALLBACK=false +SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true +SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true +SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true +SURFSENSE_ENABLE_BUSY_MUTEX=true +SURFSENSE_ENABLE_SKILLS=true +SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true +SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true +SURFSENSE_ENABLE_ACTION_LOG=true +SURFSENSE_ENABLE_REVERT_ROUTE=true +SURFSENSE_ENABLE_PERMISSION=true +SURFSENSE_ENABLE_DOOM_LOOP=true +SURFSENSE_ENABLE_STREAM_PARITY_V2=true + # Periodic connector sync interval (default: 5m) # SCHEDULE_CHECKER_INTERVAL=5m @@ -315,9 +336,24 @@ STT_SERVICE=local/base # Pages limit per user for ETL (default: unlimited) # PAGES_LIMIT=500 -# Premium token quota per registered user (default: 5M) -# Only applies to models with billing_tier=premium in global_llm_config.yaml -# PREMIUM_TOKEN_LIMIT=5000000 +# Premium credit quota per registered user, in micro-USD (default: $5). +# Premium turns are debited at the actual per-call provider cost reported +# by LiteLLM. Only applies to models with billing_tier=premium. +# PREMIUM_CREDIT_MICROS_LIMIT=5000000 +# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000 + +# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default). +# QUOTA_MAX_RESERVE_MICROS=1000000 + +# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default). +# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000 + +# Per-podcast reservation for the podcast Celery task ($0.20 default). +# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000 + +# Per-video-presentation reservation for the video Celery task ($1.00 default). +# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care. +# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000 # No-login (anonymous) mode — public users can chat without an account # Set TRUE to enable /free pages and anonymous chat API diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index a793f33d1..86c1b326f 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000 # Set FALSE to disable new checkout session creation temporarily STRIPE_PAGE_BUYING_ENABLED=TRUE -# Premium token purchases via Stripe (for premium-tier model usage) -# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens) +# Premium credit purchases via Stripe (for premium-tier model usage). +# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit +# (default 1_000_000 = $1.00). Premium turns are billed at the actual +# per-call provider cost reported by LiteLLM. STRIPE_TOKEN_BUYING_ENABLED=FALSE STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... -STRIPE_TOKENS_PER_UNIT=1000000 +STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping): +# STRIPE_TOKENS_PER_UNIT=1000000 # Periodic Stripe safety net for purchases left in PENDING (minutes old) STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 @@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300 # (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version) PAGES_LIMIT=500 -# Premium token quota per registered user (default: 3,000,000) -# Applies only to models with billing_tier=premium in global_llm_config.yaml -PREMIUM_TOKEN_LIMIT=3000000 +# Premium credit quota per registered user, in micro-USD +# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the +# actual per-call provider cost reported by LiteLLM, so cheap and expensive +# models bill proportionally. Applies only to models with +# billing_tier=premium in global_llm_config.yaml. +PREMIUM_CREDIT_MICROS_LIMIT=5000000 +# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping): +# PREMIUM_TOKEN_LIMIT=5000000 + +# Safety ceiling on per-call premium reservation, in micro-USD. +# stream_new_chat estimates an upper-bound cost from the model's +# litellm-published per-token rates × the config's quota_reserve_tokens +# and clamps to this value so a misconfigured model can't lock the +# user's whole balance on one call. Default $1.00. +QUOTA_MAX_RESERVE_MICROS=1000000 + +# Per-image reservation (in micro-USD) for the POST /image-generations +# endpoint. Bypassed for free configs. Default $0.05. +QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000 + +# Per-podcast reservation (in micro-USD) used by the podcast Celery task. +# Single envelope covers one transcript-generation LLM call. Default $0.20. +QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000 + +# Per-video-presentation reservation (in micro-USD) used by the video +# presentation Celery task. Covers worst-case fan-out of N slide-scene +# generations + refines. Default $1.00. NOTE: tasks using the override +# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care. +QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000 # No-login (anonymous) mode — allows public users to chat without an account # Set TRUE to enable /free pages and anonymous chat API @@ -294,3 +324,30 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_PLUGIN_LOADER=false # Comma-separated allowlist of plugin entry-point names # SURFSENSE_ALLOWED_PLUGINS=year_substituter + +# ----------------------------------------------------------------------------- +# Compiled-agent cache (Phase 1 + 2 perf optimization, default ON) +# ----------------------------------------------------------------------------- +# When ON, the per-turn LangGraph + middleware compile result (~3-5s of CPU +# on a cold turn) is reused across subsequent turns on the same thread, +# collapsing it to a microsecond hash lookup. All connector tools acquire +# their own short-lived DB session per call (Phase 2 refactor) so a cached +# closure is safe to share across requests. Flip OFF only as a last-resort +# rollback if you suspect cache-related staleness. +# SURFSENSE_ENABLE_AGENT_CACHE=true + +# Cache capacity (max number of compiled-agent entries kept in memory) +# and TTL per entry (seconds). Working set is typically one entry per +# active thread on this replica; tune up for very large deployments. +# SURFSENSE_AGENT_CACHE_MAXSIZE=256 +# SURFSENSE_AGENT_CACHE_TTL_SECONDS=1800 + +# ----------------------------------------------------------------------------- +# Connector discovery TTL cache (Phase 1.4 perf optimization) +# ----------------------------------------------------------------------------- +# Caches the per-search-space "available connectors" + "available document +# types" lookups that ``create_surfsense_deep_agent`` hits on every turn. +# ORM event listeners auto-invalidate on connector / document inserts, +# updates and deletes — the TTL only bounds staleness for bulk-import +# paths that bypass the ORM. Set to 0 to disable the cache. +# SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30 diff --git a/surfsense_backend/Dockerfile b/surfsense_backend/Dockerfile index 1222b36b6..73d5819b9 100644 --- a/surfsense_backend/Dockerfile +++ b/surfsense_backend/Dockerfile @@ -38,16 +38,26 @@ RUN pip install --upgrade certifi pip-system-certs COPY pyproject.toml . COPY uv.lock . -# Install PyTorch based on architecture -RUN if [ "$(uname -m)" = "x86_64" ]; then \ - pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121; \ - else \ - pip install --no-cache-dir torch torchvision torchaudio; \ - fi - -# Install python dependencies +# Install all Python dependencies from uv.lock for deterministic builds. +# +# `uv pip install -e .` re-resolves from pyproject.toml and ignores uv.lock, +# which lets prod silently drift to newer upstream versions on every rebuild +# (e.g. deepagents 0.4.x -> 0.5.x breaking the FilesystemMiddleware imports). +# Exporting the lock to requirements.txt and feeding it to `uv pip install` +# pins every transitive package to the exact version captured in uv.lock. +# +# Note on torch/CUDA: we do NOT install torch from a separate cu* index here. +# PyPI's torch wheels for Linux x86_64 already ship CUDA-enabled and pull +# nvidia-cudnn-cu13, nvidia-nccl-cu13, triton, etc. as install deps (all +# captured in uv.lock). Installing from cu121 first only wasted ~2GB of +# downloads that the lock-based install immediately replaced. If a specific +# CUDA version is needed (driver compatibility, etc.), wire it through +# [tool.uv.sources] in pyproject.toml so the lock stays the source of truth. RUN pip install --no-cache-dir uv && \ - uv pip install --system --no-cache-dir -e . + uv export --frozen --no-dev --no-hashes --no-emit-project \ + --format requirements-txt -o /tmp/requirements.txt && \ + uv pip install --system --no-cache-dir -r /tmp/requirements.txt && \ + rm /tmp/requirements.txt # Set SSL environment variables dynamically RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \ @@ -66,13 +76,18 @@ RUN cd /root/.EasyOCR/model && (unzip -o english_g2.zip || true) && (unzip -o cr # Pre-download Docling models RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true -# Install Playwright browsers for web scraping if needed -RUN pip install playwright && \ - playwright install chromium --with-deps +# Install Playwright browsers for web scraping (the playwright package itself +# is already installed via uv.lock above) +RUN playwright install chromium --with-deps # Copy source code COPY . . +# Install the project itself in editable mode. Dependencies were already +# installed deterministically from uv.lock above, so --no-deps prevents any +# re-resolution that could pull newer versions. +RUN uv pip install --system --no-cache-dir --no-deps -e . + # Copy and set permissions for entrypoint script # Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts) COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh diff --git a/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py new file mode 100644 index 000000000..64aa699e8 --- /dev/null +++ b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py @@ -0,0 +1,291 @@ +"""rename premium token columns to credit micros and add cost_micros to token_usage + +Migrates the premium quota system from a flat token counter to a USD-cost +based credit system, where 1 credit = 1 micro-USD ($0.000001). + +Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe +price means every existing value is already correct in the new unit, no +data transformation needed): + + user.premium_tokens_limit -> premium_credit_micros_limit + user.premium_tokens_used -> premium_credit_micros_used + user.premium_tokens_reserved -> premium_credit_micros_reserved + + premium_token_purchases.tokens_granted -> credit_micros_granted + +New column for cost auditing per turn: + + token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0) + +The "user" table is in zero_publication's column list (added in 139), so +this migration must drop and recreate the publication with the renamed +column names, otherwise zero-cache will replicate stale column names and +the FE Zero schema will fail to bind. + +IMPORTANT - before AND after running this migration: + 1. Stop zero-cache (it holds replication locks that will deadlock DDL) + 2. Run: alembic upgrade head + 3. Delete / reset the zero-cache data volume + 4. Restart zero-cache (it will do a fresh initial sync) + +Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on +"user". Skipping the data-volume reset will leave IndexedDB clients seeing +column-not-found errors from a stale catalog snapshot. + +Revision ID: 140 +Revises: 139 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "140" +down_revision: str | None = "139" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +PUBLICATION_NAME = "zero_publication" + +# Replicates 139's document column list verbatim — must stay in sync. +DOCUMENT_COLS = [ + "id", + "title", + "document_type", + "search_space_id", + "folder_id", + "created_by_id", + "status", + "created_at", + "updated_at", +] + +# Same five live-meter fields as 139, with the renamed column names. +USER_COLS = [ + "id", + "pages_limit", + "pages_used", + "premium_credit_micros_limit", + "premium_credit_micros_used", +] + + +def _terminate_blocked_pids(conn, table: str) -> None: + """Kill backends whose locks on *table* would block our AccessExclusiveLock.""" + conn.execute( + sa.text( + "SELECT pg_terminate_backend(l.pid) " + "FROM pg_locks l " + "JOIN pg_class c ON c.oid = l.relation " + "WHERE c.relname = :tbl " + " AND l.pid != pg_backend_pid()" + ), + {"tbl": table}, + ) + + +def _has_zero_version(conn, table: str) -> bool: + return ( + conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :tbl AND column_name = '_0_version'" + ), + {"tbl": table}, + ).fetchone() + is not None + ) + + +def _column_exists(conn, table: str, column: str) -> bool: + return ( + conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :tbl AND column_name = :col" + ), + {"tbl": table, "col": column}, + ).fetchone() + is not None + ) + + +def _build_publication_ddl( + user_cols: list[str], + *, + documents_has_zero_ver: bool, + user_has_zero_ver: bool, +) -> str: + doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) + user_col_list_with_meta = user_cols + ( + ['"_0_version"'] if user_has_zero_ver else [] + ) + doc_col_list = ", ".join(doc_cols) + user_col_list = ", ".join(user_col_list_with_meta) + return ( + f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + f"notifications, " + f"documents ({doc_col_list}), " + f"folders, " + f"search_source_connectors, " + f"new_chat_messages, " + f"chat_comments, " + f"chat_session_state, " + f'"user" ({user_col_list})' + ) + + +def upgrade() -> None: + conn = op.get_bind() + + # ------------------------------------------------------------------ + # 1. Add cost_micros to token_usage. Idempotent guard so re-runs in + # dev environments are safe. + # ------------------------------------------------------------------ + if not _column_exists(conn, "token_usage", "cost_micros"): + op.add_column( + "token_usage", + sa.Column( + "cost_micros", + sa.BigInteger(), + nullable=False, + server_default="0", + ), + ) + + # ------------------------------------------------------------------ + # 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted. + # ------------------------------------------------------------------ + if _column_exists( + conn, "premium_token_purchases", "tokens_granted" + ) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"): + op.alter_column( + "premium_token_purchases", + "tokens_granted", + new_column_name="credit_micros_granted", + ) + + # ------------------------------------------------------------------ + # 3. Rename user.premium_tokens_* -> premium_credit_micros_*. + # + # We must drop the publication first (it references the old column + # names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE + # in a transaction block; alembic's outer transaction already holds + # one, but a SAVEPOINT keeps the LOCK + DDL atomic. + # ------------------------------------------------------------------ + tx = conn.begin_nested() if conn.in_transaction() else conn.begin() + with tx: + conn.execute(sa.text("SET lock_timeout = '10s'")) + + _terminate_blocked_pids(conn, "user") + conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE')) + + # Re-assert REPLICA IDENTITY DEFAULT for safety; column-list + # publications require at least the PK to be in the column list, + # which is true for both the old and new shape. + conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT')) + + # Drop the publication BEFORE renaming columns, otherwise Postgres + # rejects the rename: "cannot drop column ... referenced by + # publication". + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + + for old, new in ( + ("premium_tokens_limit", "premium_credit_micros_limit"), + ("premium_tokens_used", "premium_credit_micros_used"), + ("premium_tokens_reserved", "premium_credit_micros_reserved"), + ): + if _column_exists(conn, "user", old) and not _column_exists( + conn, "user", new + ): + op.alter_column("user", old, new_column_name=new) + + # Update the server_default on the renamed limit column so newly + # inserted users get $5 of credit (== 5_000_000 micros) by + # default. Existing rows are unaffected. + op.alter_column( + "user", + "premium_credit_micros_limit", + server_default="5000000", + ) + + # Recreate the publication with the new column names. + documents_has_zero_ver = _has_zero_version(conn, "documents") + user_has_zero_ver = _has_zero_version(conn, "user") + conn.execute( + sa.text( + _build_publication_ddl( + USER_COLS, + documents_has_zero_ver=documents_has_zero_ver, + user_has_zero_ver=user_has_zero_ver, + ) + ) + ) + + +def downgrade() -> None: + """Revert the rename and drop ``cost_micros``. + + Mirrors ``upgrade``: drop the publication, rename columns back, drop + the new column, recreate the publication with the old column list. + Same zero-cache stop/reset runbook applies in reverse. + """ + conn = op.get_bind() + + tx = conn.begin_nested() if conn.in_transaction() else conn.begin() + with tx: + conn.execute(sa.text("SET lock_timeout = '10s'")) + + _terminate_blocked_pids(conn, "user") + conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE')) + + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + + for new, old in ( + ("premium_credit_micros_limit", "premium_tokens_limit"), + ("premium_credit_micros_used", "premium_tokens_used"), + ("premium_credit_micros_reserved", "premium_tokens_reserved"), + ): + if _column_exists(conn, "user", new) and not _column_exists( + conn, "user", old + ): + op.alter_column("user", new, new_column_name=old) + + op.alter_column( + "user", + "premium_tokens_limit", + server_default="5000000", + ) + + legacy_user_cols = [ + "id", + "pages_limit", + "pages_used", + "premium_tokens_limit", + "premium_tokens_used", + ] + documents_has_zero_ver = _has_zero_version(conn, "documents") + user_has_zero_ver = _has_zero_version(conn, "user") + conn.execute( + sa.text( + _build_publication_ddl( + legacy_user_cols, + documents_has_zero_ver=documents_has_zero_ver, + user_has_zero_ver=user_has_zero_ver, + ) + ) + ) + + if _column_exists( + conn, "premium_token_purchases", "credit_micros_granted" + ) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"): + op.alter_column( + "premium_token_purchases", + "credit_micros_granted", + new_column_name="tokens_granted", + ) + + if _column_exists(conn, "token_usage", "cost_micros"): + op.drop_column("token_usage", "cost_micros") diff --git a/surfsense_backend/app/agents/new_chat/agent_cache.py b/surfsense_backend/app/agents/new_chat/agent_cache.py new file mode 100644 index 000000000..fa8e6fb72 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/agent_cache.py @@ -0,0 +1,357 @@ +"""TTL-LRU cache for compiled SurfSense deep agents. + +Why this exists +--------------- + +``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat +turn: + +1. Discover connectors & document types from Postgres (~50-200ms) +2. Build the tool list (built-in + MCP) (~200ms-1.7s) +3. Compose the system prompt +4. Construct ~15 middleware instances (CPU) +5. Eagerly compile the general-purpose subagent + (``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously, + which builds a second LangGraph + Pydantic schemas — ~1.5-2s of pure + CPU work) +6. Compile the outer LangGraph + +For a single thread, all six steps produce the SAME object on every turn +unless the user has changed their LLM config, toggled a feature flag, +added a connector, etc. The right answer is to compile ONCE per +"agent shape" and reuse the resulting :class:`CompiledStateGraph` for +every subsequent turn on the same thread. + +Why a per-thread key (not a global pool) +---------------------------------------- + +Most middleware in the SurfSense stack captures per-thread state in +``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``, +``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse +would silently leak state across users and threads. Keying the cache on +``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated +turns on the same thread without changing any middleware's behavior. + +Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema` +(read via ``runtime.context``) so the cache can collapse to a single +``(llm_config_id, search_space_id, ...)`` key shared across threads. Until +then, per-thread keying is the only safe option. + +Cache shape +----------- + +* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s, 30 + minutes — matches a typical chat session). ``maxsize`` (default 256) + caps memory; LRU evicts least-recently-used on overflow. +* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent + cold misses on the same key wait for the first build instead of + building N times. +* Process-local: this is an in-memory cache. Multi-replica deployments + pay the build cost once per replica per key. That's fine; the working + set per replica is small (one entry per active thread on that replica). + +Telemetry +--------- + +Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``: + + * ``hit`` — cache hit, microseconds-fast + * ``miss`` — first build for this key, includes build duration + * ``stale`` — entry was found but expired; rebuilt + * ``evict`` — LRU eviction (size-limited) + * ``size`` — current cache occupancy at lookup time +""" + +from __future__ import annotations + +import asyncio +import hashlib +import logging +import os +import time +from collections import OrderedDict +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any + +from app.utils.perf import get_perf_logger + +logger = logging.getLogger(__name__) +_perf_log = get_perf_logger() + + +# --------------------------------------------------------------------------- +# Public API: signature helpers (cache key components) +# --------------------------------------------------------------------------- + + +def stable_hash(*parts: Any) -> str: + """Compute a deterministic SHA1 of the str repr of ``parts``. + + Used for cache key components that need a fixed-width representation + (system prompt, tool list, etc.). SHA1 is fine here — this is not a + security boundary, just a content fingerprint. + """ + h = hashlib.sha1(usedforsecurity=False) + for p in parts: + h.update(repr(p).encode("utf-8", errors="replace")) + h.update(b"\x1f") # ASCII unit separator between parts + return h.hexdigest() + + +def tools_signature( + tools: list[Any] | tuple[Any, ...], + *, + available_connectors: list[str] | None, + available_document_types: list[str] | None, +) -> str: + """Hash the bound-tool surface for cache-key purposes. + + The signature changes whenever: + + * A tool is added or removed from the bound list (built-in toggles, + MCP tools loaded for the user changes, gating rules flip, etc.). + * The available connectors / document types for the search space + change (new connector added, last connector removed, new document + type indexed). Because :func:`get_connector_gated_tools` derives + ``modified_disabled_tools`` from ``available_connectors``, the + tool surface is technically already covered — but we hash the + connector list separately so an empty-list "no tools changed" + situation still rotates the key when, say, the user re-adds a + connector that gates a tool we were already not exposing. + + Stays stable across: + + * Process restarts (tool names + descriptions are static). + * Different replicas (everyone gets the same hash for the same + inputs). + """ + tool_descriptors = sorted( + (getattr(t, "name", repr(t)), getattr(t, "description", "")) for t in tools + ) + connectors = sorted(available_connectors or []) + doc_types = sorted(available_document_types or []) + return stable_hash(tool_descriptors, connectors, doc_types) + + +def flags_signature(flags: Any) -> str: + """Hash the resolved :class:`AgentFeatureFlags` dataclass. + + Frozen dataclasses are deterministically reprable, so a SHA1 of their + repr is a stable fingerprint. Restart safe (flags are read once at + process boot). + """ + return stable_hash(repr(flags)) + + +def system_prompt_hash(system_prompt: str) -> str: + """Hash a system prompt string. Cheap, ~30µs for typical prompts.""" + return hashlib.sha1( + system_prompt.encode("utf-8", errors="replace"), + usedforsecurity=False, + ).hexdigest() + + +# --------------------------------------------------------------------------- +# Cache implementation +# --------------------------------------------------------------------------- + + +@dataclass +class _Entry: + value: Any + created_at: float + last_used_at: float + + +class _AgentCache: + """In-process TTL-LRU cache with per-key in-flight de-duplication. + + NOT THREAD-SAFE in the multithreading sense — designed for a single + asyncio event loop. Uvicorn runs one event loop per worker process, + so this is fine; multi-worker deployments simply each maintain their + own cache. + """ + + def __init__(self, *, maxsize: int, ttl_seconds: float) -> None: + self._maxsize = maxsize + self._ttl = ttl_seconds + self._entries: OrderedDict[str, _Entry] = OrderedDict() + # One lock per key — guards "build" so concurrent cold misses on + # the same key wait for the first build instead of all racing. + self._locks: dict[str, asyncio.Lock] = {} + + def _now(self) -> float: + return time.monotonic() + + def _is_fresh(self, entry: _Entry) -> bool: + return (self._now() - entry.created_at) < self._ttl + + def _evict_if_full(self) -> None: + while len(self._entries) >= self._maxsize: + evicted_key, _ = self._entries.popitem(last=False) + self._locks.pop(evicted_key, None) + _perf_log.info( + "[agent_cache] evict key=%s reason=lru size=%d", + _short(evicted_key), + len(self._entries), + ) + + def _touch(self, key: str, entry: _Entry) -> None: + entry.last_used_at = self._now() + self._entries.move_to_end(key, last=True) + + async def get_or_build( + self, + key: str, + *, + builder: Callable[[], Awaitable[Any]], + ) -> Any: + """Return the cached value for ``key`` or call ``builder()`` to make it. + + ``builder`` MUST be idempotent — concurrent cold misses on the + same key collapse to a single ``builder()`` call (the others + wait on the in-flight lock and observe the populated entry on + wake). + """ + # Fast path: hot hit. + entry = self._entries.get(key) + if entry is not None and self._is_fresh(entry): + self._touch(key, entry) + _perf_log.info( + "[agent_cache] hit key=%s age=%.1fs size=%d", + _short(key), + self._now() - entry.created_at, + len(self._entries), + ) + return entry.value + + # Stale entry — drop it; rebuild below. + if entry is not None and not self._is_fresh(entry): + _perf_log.info( + "[agent_cache] stale key=%s age=%.1fs ttl=%.0fs", + _short(key), + self._now() - entry.created_at, + self._ttl, + ) + self._entries.pop(key, None) + + # Slow path: serialize concurrent misses for the same key. + lock = self._locks.setdefault(key, asyncio.Lock()) + async with lock: + # Double-check after acquiring the lock — another waiter may + # have populated the entry while we slept. + entry = self._entries.get(key) + if entry is not None and self._is_fresh(entry): + self._touch(key, entry) + _perf_log.info( + "[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true", + _short(key), + self._now() - entry.created_at, + len(self._entries), + ) + return entry.value + + t0 = time.perf_counter() + try: + value = await builder() + except BaseException: + # Don't cache failed builds; let the next caller retry. + _perf_log.warning( + "[agent_cache] build_failed key=%s elapsed=%.3fs", + _short(key), + time.perf_counter() - t0, + ) + raise + elapsed = time.perf_counter() - t0 + + # Insert + evict. + self._evict_if_full() + now = self._now() + self._entries[key] = _Entry(value=value, created_at=now, last_used_at=now) + self._entries.move_to_end(key, last=True) + _perf_log.info( + "[agent_cache] miss key=%s build=%.3fs size=%d", + _short(key), + elapsed, + len(self._entries), + ) + return value + + def invalidate(self, key: str) -> bool: + """Drop a single entry; return True if anything was removed.""" + removed = self._entries.pop(key, None) is not None + self._locks.pop(key, None) + if removed: + _perf_log.info( + "[agent_cache] invalidate key=%s size=%d", + _short(key), + len(self._entries), + ) + return removed + + def invalidate_prefix(self, prefix: str) -> int: + """Drop every entry whose key starts with ``prefix``. Returns count.""" + keys = [k for k in self._entries if k.startswith(prefix)] + for k in keys: + self._entries.pop(k, None) + self._locks.pop(k, None) + if keys: + _perf_log.info( + "[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d", + _short(prefix), + len(keys), + len(self._entries), + ) + return len(keys) + + def clear(self) -> None: + n = len(self._entries) + self._entries.clear() + self._locks.clear() + if n: + _perf_log.info("[agent_cache] clear removed=%d", n) + + def stats(self) -> dict[str, Any]: + return { + "size": len(self._entries), + "maxsize": self._maxsize, + "ttl_seconds": self._ttl, + } + + +def _short(key: str, n: int = 16) -> str: + """Truncate keys for log lines so they don't blow up log volume.""" + return key if len(key) <= n else f"{key[:n]}..." + + +# --------------------------------------------------------------------------- +# Module-level singleton +# --------------------------------------------------------------------------- + +_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256")) +_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800")) + +_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL) + + +def get_cache() -> _AgentCache: + """Return the process-wide compiled-agent cache singleton.""" + return _cache + + +def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache: + """Replace the singleton with a fresh cache. Tests only.""" + global _cache + _cache = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds) + return _cache + + +__all__ = [ + "flags_signature", + "get_cache", + "reload_for_tests", + "stable_hash", + "system_prompt_hash", + "tools_signature", +] diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index c0e9a3b96..36739adae 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -40,6 +40,13 @@ from langchain_core.tools import BaseTool from langgraph.types import Checkpointer from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.agent_cache import ( + flags_signature, + get_cache, + stable_hash, + system_prompt_hash, + tools_signature, +) from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags from app.agents.new_chat.filesystem_backends import build_backend_resolver @@ -53,6 +60,7 @@ from app.agents.new_chat.middleware import ( DedupHITLToolCallsMiddleware, DoomLoopMiddleware, FileIntentMiddleware, + FlattenSystemMessageMiddleware, KnowledgeBasePersistenceMiddleware, KnowledgePriorityMiddleware, KnowledgeTreeMiddleware, @@ -330,23 +338,39 @@ async def create_surfsense_deep_agent( else None, ) - # Discover available connectors and document types for this search space + # Discover available connectors and document types for this search space. + # + # NOTE: These two calls cannot be parallelized via ``asyncio.gather``. + # ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``); + # SQLAlchemy explicitly forbids concurrent operations on the same session + # ("This session is provisioning a new connection; concurrent operations + # are not permitted on the same session"). The Phase 1.4 in-process TTL + # cache in ``connector_service`` already collapses the warm path to a + # near-zero pair of dict lookups, so sequential awaits cost nothing in + # the common case while remaining correct on cold cache misses. available_connectors: list[str] | None = None available_document_types: list[str] | None = None _t0 = time.perf_counter() try: - connector_types = await connector_service.get_available_connectors( - search_space_id - ) - if connector_types: - available_connectors = _map_connectors_to_searchable_types(connector_types) + try: + connector_types_result = await connector_service.get_available_connectors( + search_space_id + ) + if connector_types_result: + available_connectors = _map_connectors_to_searchable_types( + connector_types_result + ) + except Exception as e: + logging.warning("Failed to discover available connectors: %s", e) - available_document_types = await connector_service.get_available_document_types( - search_space_id - ) - - except Exception as e: + try: + available_document_types = ( + await connector_service.get_available_document_types(search_space_id) + ) + except Exception as e: + logging.warning("Failed to discover available document types: %s", e) + except Exception as e: # pragma: no cover - defensive outer guard logging.warning(f"Failed to discover available connectors/document types: {e}") _perf_log.info( "[create_agent] Connector/doc-type discovery in %.3fs", @@ -469,29 +493,77 @@ async def create_surfsense_deep_agent( # entire middleware build + main-graph compile into a single # ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the # event loop stays responsive. + # + # PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed + # on every per-request value that any middleware in the stack closes + # over in ``__init__`` — drop one and you risk leaking state across + # threads. Hits collapse this whole block to a microsecond lookup; + # misses pay the original CPU cost AND populate the cache. + config_id = agent_config.config_id if agent_config is not None else None + + async def _build_agent() -> Any: + return await asyncio.to_thread( + _build_compiled_agent_blocking, + llm=llm, + tools=tools, + final_system_prompt=final_system_prompt, + backend_resolver=backend_resolver, + filesystem_mode=filesystem_selection.mode, + search_space_id=search_space_id, + user_id=user_id, + thread_id=thread_id, + visibility=visibility, + anon_session_id=anon_session_id, + available_connectors=available_connectors, + available_document_types=available_document_types, + # ``mentioned_document_ids`` is consumed by + # ``KnowledgePriorityMiddleware`` per turn via + # ``runtime.context`` (Phase 1.5). We still pass the + # caller-provided list here for the legacy fallback path + # (cache disabled / context not propagated) — the middleware + # drains its own copy after the first read so a cached graph + # never replays stale mentions. + mentioned_document_ids=mentioned_document_ids, + max_input_tokens=_max_input_tokens, + flags=_flags, + checkpointer=checkpointer, + ) + _t0 = time.perf_counter() - agent = await asyncio.to_thread( - _build_compiled_agent_blocking, - llm=llm, - tools=tools, - final_system_prompt=final_system_prompt, - backend_resolver=backend_resolver, - filesystem_mode=filesystem_selection.mode, - search_space_id=search_space_id, - user_id=user_id, - thread_id=thread_id, - visibility=visibility, - anon_session_id=anon_session_id, - available_connectors=available_connectors, - available_document_types=available_document_types, - mentioned_document_ids=mentioned_document_ids, - max_input_tokens=_max_input_tokens, - flags=_flags, - checkpointer=checkpointer, - ) + if _flags.enable_agent_cache and not _flags.disable_new_agent_stack: + # Cache key components — order matters only for human readability; + # the resulting hash is what's stored. Every component must + # rotate on a real shape change AND stay stable across identical + # invocations. + cache_key = stable_hash( + "v1", # schema version of the key — bump if components change + config_id, + thread_id, + user_id, + search_space_id, + visibility, + filesystem_selection.mode, + anon_session_id, + tools_signature( + tools, + available_connectors=available_connectors, + available_document_types=available_document_types, + ), + flags_signature(_flags), + system_prompt_hash(final_system_prompt), + _max_input_tokens, + # ``mentioned_document_ids`` deliberately omitted — middleware + # reads it from ``runtime.context`` (Phase 1.5). + ) + agent = await get_cache().get_or_build(cache_key, builder=_build_agent) + else: + agent = await _build_agent() _perf_log.info( - "[create_agent] Middleware stack + graph compiled in %.3fs", + "[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)", time.perf_counter() - _t0, + "on" + if _flags.enable_agent_cache and not _flags.disable_new_agent_stack + else "off", ) _perf_log.info( @@ -1038,6 +1110,14 @@ def _build_compiled_agent_blocking( noop_mw, retry_mw, fallback_mw, + # Coalesce a multi-text-block system message into one block + # immediately before the model call. Sits innermost on the + # system-message-mutation chain so it observes every appender + # (todo / filesystem / skills / subagents …) and prevents + # OpenRouter→Anthropic from redistributing ``cache_control`` + # across N blocks and tripping Anthropic's 4-breakpoint cap. + # See ``middleware/flatten_system.py`` for full rationale. + FlattenSystemMessageMiddleware(), # Tool-call repair must run after model emits but before # permission / dedup / doom-loop interpret the calls. repair_mw, diff --git a/surfsense_backend/app/agents/new_chat/context.py b/surfsense_backend/app/agents/new_chat/context.py index c1fe45aaa..d720b524b 100644 --- a/surfsense_backend/app/agents/new_chat/context.py +++ b/surfsense_backend/app/agents/new_chat/context.py @@ -1,10 +1,25 @@ """ Context schema definitions for SurfSense agents. -This module defines the custom state schema used by the SurfSense deep agent. +This module defines the per-invocation context object passed to the SurfSense +deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6). + +The agent's compiled graph is the same across invocations (and cached by +``agent_cache``), so anything that varies per turn — the user mentions a +specific document, the front-end issues a unique ``request_id``, etc. — +MUST live on this context object instead of being captured into a +middleware ``__init__`` closure. Middlewares read fields back via +``runtime.context.``; tools read them via ``runtime.context``. + +This object is read inside both ``KnowledgePriorityMiddleware`` (for +``mentioned_document_ids``) and any future middleware that needs +per-request state without invalidating the compiled-agent cache. """ -from typing import NotRequired, TypedDict +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TypedDict class FileOperationContractState(TypedDict): @@ -15,25 +30,35 @@ class FileOperationContractState(TypedDict): turn_id: str -class SurfSenseContextSchema(TypedDict): +@dataclass +class SurfSenseContextSchema: """ - Custom state schema for the SurfSense deep agent. + Per-invocation context for the SurfSense deep agent. - This extends the default agent state with custom fields. - The default state already includes: - - messages: Conversation history - - todos: Task list from TodoListMiddleware - - files: Virtual filesystem from FilesystemMiddleware + Defaults are chosen so the dataclass can be safely default-constructed + (LangGraph's ``Runtime.context`` itself defaults to ``None`` if no + context is supplied — see ``langgraph.runtime.Runtime``). All fields + are optional; consumers must None-check before reading. - We're adding fields needed for knowledge base search: - - search_space_id: The user's search space ID - - db_session: Database session (injected at runtime) - - connector_service: Connector service instance (injected at runtime) + Phase 1.5 fields: + search_space_id: Search space the request is scoped to. + mentioned_document_ids: KB documents the user @-mentioned this turn. + Read by ``KnowledgePriorityMiddleware`` to seed its priority + list. Stays out of the compiled-agent cache key — that's the + whole point of putting it here. + file_operation_contract: One-shot file operation contract emitted + by ``FileIntentMiddleware`` for the upcoming turn. + turn_id / request_id: Correlation IDs surfaced by the streaming + task; populated for telemetry. + + Phase 2 will extend with: thread_id, user_id, visibility, + filesystem_mode, anon_session_id, available_connectors, + available_document_types, created_by_id (everything currently captured + by middleware ``__init__`` closures). """ - search_space_id: int - file_operation_contract: NotRequired[FileOperationContractState] - turn_id: NotRequired[str] - request_id: NotRequired[str] - # These are runtime-injected and won't be serialized - # db_session and connector_service are passed when invoking the agent + search_space_id: int | None = None + mentioned_document_ids: list[int] = field(default_factory=list) + file_operation_contract: FileOperationContractState | None = None + turn_id: str | None = None + request_id: str | None = None diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index f58bf0dd7..1f5a08ec6 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -3,8 +3,10 @@ Feature flags for the SurfSense new_chat agent stack. These flags gate the newer agent middleware (some ported from OpenCode, some sourced from ``langchain.agents.middleware`` / ``deepagents``, some -SurfSense-native). They follow a "default-OFF for risky things, -default-ON for safe upgrades, master kill-switch for everything new" model. +SurfSense-native). Most shipped agent-stack upgrades default ON so Docker +image updates work even when older installs do not have newly introduced +environment variables. Risky/experimental integrations stay default OFF, +and the master kill-switch can still disable everything new. All new middleware checks its flag at agent build time. If the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new @@ -14,16 +16,19 @@ operators a single switch to revert to pre-port behavior. Examples -------- -Local development (recommended for trying everything except doom-loop / selector): +Defaults: SURFSENSE_ENABLE_CONTEXT_EDITING=true SURFSENSE_ENABLE_COMPACTION_V2=true SURFSENSE_ENABLE_RETRY_AFTER=true + SURFSENSE_ENABLE_MODEL_FALLBACK=false + SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true + SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true - SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy - SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships - SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false - SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events + SURFSENSE_ENABLE_PERMISSION=true + SURFSENSE_ENABLE_DOOM_LOOP=true + SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call + SURFSENSE_ENABLE_STREAM_PARITY_V2=true Master kill-switch (overrides everything else): @@ -60,32 +65,28 @@ class AgentFeatureFlags: disable_new_agent_stack: bool = False # Agent quality — context budget, retry/limits, name-repair, doom-loop - enable_context_editing: bool = False - enable_compaction_v2: bool = False - enable_retry_after: bool = False + enable_context_editing: bool = True + enable_compaction_v2: bool = True + enable_retry_after: bool = True enable_model_fallback: bool = False - enable_model_call_limit: bool = False - enable_tool_call_limit: bool = False - enable_tool_call_repair: bool = False - enable_doom_loop: bool = ( - False # Default OFF until UI handles permission='doom_loop' - ) + enable_model_call_limit: bool = True + enable_tool_call_limit: bool = True + enable_tool_call_repair: bool = True + enable_doom_loop: bool = True # Safety — permissions, concurrency, tool-set narrowing - enable_permission: bool = False # Default OFF for first deploy - enable_busy_mutex: bool = False + enable_permission: bool = True + enable_busy_mutex: bool = True enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost # Skills + subagents - enable_skills: bool = False - enable_specialized_subagents: bool = False - enable_kb_planner_runnable: bool = False + enable_skills: bool = True + enable_specialized_subagents: bool = True + enable_kb_planner_runnable: bool = True # Snapshot / revert - enable_action_log: bool = False - enable_revert_route: bool = ( - False # Backend ships before UI; route returns 503 until this flips - ) + enable_action_log: bool = True + enable_revert_route: bool = True # Streaming parity v2 — opt in to LangChain's structured # ``AIMessageChunk`` content (typed reasoning blocks, tool-input @@ -94,7 +95,7 @@ class AgentFeatureFlags: # text path and the synthetic ``call_`` tool-call id (no # ``langchainToolCallId`` propagation). Schema migrations 135/136 # ship unconditionally because they're forward-compatible. - enable_stream_parity_v2: bool = False + enable_stream_parity_v2: bool = True # Plugins enable_plugin_loader: bool = False @@ -102,6 +103,41 @@ class AgentFeatureFlags: # Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT) enable_otel: bool = False + # Performance — compiled-agent cache (Phase 1 + Phase 2). + # When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled + # graph if the cache key matches (LLM config + thread + tool surface + + # flags + system prompt + filesystem mode). Cuts per-turn agent-build + # wall clock from ~4-5s to <50µs on cache hits. + # + # SAFETY (Phase 2 unblocked this default-on): + # All connector mutation tools (``tools/notion``, ``tools/gmail``, + # ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``, + # ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``, + # ``tools/teams``, ``tools/luma``, ``connected_accounts``, + # ``update_memory``, ``search_surfsense_docs``) now acquire fresh + # short-lived ``AsyncSession`` instances per call via + # :data:`async_session_maker`. The factory still accepts ``db_session`` + # for registry compatibility but ``del``'s it immediately — see any + # of those files' factory docstrings for the rationale. The ``llm`` + # closure is per-(provider, model, config_id) which is already in + # the cache key, so the LLM is safe to share across cached hits of + # the same key. The KB priority middleware reads + # ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5), + # not its constructor closure, so the same compiled agent serves + # turns with different mention lists correctly. + # + # Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the + # environment if a regression surfaces. The path is exercised by + # the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite. + enable_agent_cache: bool = True + # Phase 1 (deferred — measure first): pre-build & share the + # general-purpose subagent ``CompiledSubAgent`` across cold-cache + # misses. Only helps when the outer cache MISSES (cache hits already + # reuse the entire SubAgentMiddleware-compiled graph). Off by default + # until we have data showing cold misses are frequent enough to + # justify the extra global state. + enable_agent_cache_share_gp_subagent: bool = False + @classmethod def from_env(cls) -> AgentFeatureFlags: """Read flags from environment. @@ -115,48 +151,76 @@ class AgentFeatureFlags: "SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent " "middleware is forced OFF for this build." ) - return cls(disable_new_agent_stack=True) + return cls( + disable_new_agent_stack=True, + enable_context_editing=False, + enable_compaction_v2=False, + enable_retry_after=False, + enable_model_fallback=False, + enable_model_call_limit=False, + enable_tool_call_limit=False, + enable_tool_call_repair=False, + enable_doom_loop=False, + enable_permission=False, + enable_busy_mutex=False, + enable_llm_tool_selector=False, + enable_skills=False, + enable_specialized_subagents=False, + enable_kb_planner_runnable=False, + enable_action_log=False, + enable_revert_route=False, + enable_stream_parity_v2=False, + enable_plugin_loader=False, + enable_otel=False, + enable_agent_cache=False, + enable_agent_cache_share_gp_subagent=False, + ) return cls( disable_new_agent_stack=False, # Agent quality - enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False), - enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False), - enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False), + enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True), + enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True), + enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True), enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False), enable_model_call_limit=_env_bool( - "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False + "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True ), - enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False), + enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True), enable_tool_call_repair=_env_bool( - "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False + "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True ), - enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False), + enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True), # Safety - enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False), - enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False), + enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True), + enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True), enable_llm_tool_selector=_env_bool( "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False ), # Skills + subagents - enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False), + enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True), enable_specialized_subagents=_env_bool( - "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False + "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True ), enable_kb_planner_runnable=_env_bool( - "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False + "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True ), # Snapshot / revert - enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False), - enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False), + enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True), + enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True), # Streaming parity v2 enable_stream_parity_v2=_env_bool( - "SURFSENSE_ENABLE_STREAM_PARITY_V2", False + "SURFSENSE_ENABLE_STREAM_PARITY_V2", True ), # Plugins enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), # Observability enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False), + # Performance + enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True), + enable_agent_cache_share_gp_subagent=_env_bool( + "SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False + ), ) def any_new_middleware_enabled(self) -> bool: diff --git a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/new_chat/llm_config.py index 99bb719f6..bc37bf1c4 100644 --- a/surfsense_backend/app/agents/new_chat/llm_config.py +++ b/surfsense_backend/app/agents/new_chat/llm_config.py @@ -90,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM): yield chunk -# Provider mapping for LiteLLM model string construction -PROVIDER_MAP = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "XAI": "xai", - "BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "COMETAPI": "cometapi", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "CUSTOM": "custom", -} +# Provider mapping for LiteLLM model string construction. +# +# Single source of truth lives in +# :mod:`app.services.provider_capabilities` so the YAML loader (which +# runs during ``app.config`` class-body init) can resolve provider +# prefixes without dragging the agent / tools tree into module load +# order. Re-exported here under the historical ``PROVIDER_MAP`` name +# so existing callers (``llm_router_service``, ``image_gen_router_service``, +# tests) keep working unchanged. +from app.services.provider_capabilities import ( # noqa: E402 + _PROVIDER_PREFIX_MAP as PROVIDER_MAP, +) def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: @@ -178,6 +155,17 @@ class AgentConfig: anonymous_enabled: bool = False quota_reserve_tokens: int | None = None + # Capability flag: best-effort True for the chat selector / catalog. + # Resolved via :func:`provider_capabilities.derive_supports_image_input` + # which prefers OpenRouter's ``architecture.input_modalities`` and + # otherwise consults LiteLLM's authoritative model map. Default True + # is the conservative-allow stance — the streaming-task safety net + # (``is_known_text_only_chat_model``) is the *only* place a False + # actually blocks a request. Setting this to False here without an + # authoritative source would silently hide vision-capable models + # (the regression we're fixing). + supports_image_input: bool = True + @classmethod def from_auto_mode(cls) -> "AgentConfig": """ @@ -203,6 +191,12 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, + # Auto routes across the configured pool, which usually + # contains at least one vision-capable deployment; the router + # will surface a 404 from a non-vision deployment as a normal + # ``allowed_fails`` event and fail over rather than blocking + # the request outright. + supports_image_input=True, ) @classmethod @@ -216,10 +210,24 @@ class AgentConfig: Returns: AgentConfig instance """ - return cls( - provider=config.provider.value + # Lazy import to avoid pulling provider_capabilities (and its + # transitive litellm import) into module-init order. + from app.services.provider_capabilities import derive_supports_image_input + + provider_value = ( + config.provider.value if hasattr(config.provider, "value") - else str(config.provider), + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + + return cls( + provider=provider_value, model_name=config.model_name, api_key=config.api_key, api_base=config.api_base, @@ -235,6 +243,16 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, + # BYOK rows have no operator-curated capability flag, so we + # ask LiteLLM (default-allow on unknown). The streaming + # safety net still blocks if the model is *explicitly* + # marked text-only. + supports_image_input=derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ), ) @classmethod @@ -253,15 +271,46 @@ class AgentConfig: Returns: AgentConfig instance """ + # Lazy import to avoid pulling provider_capabilities (and its + # transitive litellm import) into module-init order. + from app.services.provider_capabilities import derive_supports_image_input + # Get system instructions from YAML, default to empty string system_instructions = yaml_config.get("system_instructions", "") + provider = yaml_config.get("provider", "").upper() + model_name = yaml_config.get("model_name", "") + custom_provider = yaml_config.get("custom_provider") + litellm_params = yaml_config.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + + # Explicit YAML override wins; otherwise derive from LiteLLM / + # OpenRouter modalities. The YAML loader already populates this + # field, but this method is also called from + # ``load_global_llm_config_by_id``'s file fallback (hot reload), + # so we re-derive here for safety. The bool() coercion preserves + # the loader's behaviour for explicit ``true`` / ``false`` + # strings that PyYAML may surface. + if "supports_image_input" in yaml_config: + supports_image_input = bool(yaml_config.get("supports_image_input")) + else: + supports_image_input = derive_supports_image_input( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ) + return cls( - provider=yaml_config.get("provider", "").upper(), - model_name=yaml_config.get("model_name", ""), + provider=provider, + model_name=model_name, api_key=yaml_config.get("api_key", ""), api_base=yaml_config.get("api_base"), - custom_provider=yaml_config.get("custom_provider"), + custom_provider=custom_provider, litellm_params=yaml_config.get("litellm_params"), # Prompt configuration from YAML (with defaults for backwards compatibility) system_instructions=system_instructions if system_instructions else None, @@ -276,6 +325,7 @@ class AgentConfig: is_premium=yaml_config.get("billing_tier", "free") == "premium", anonymous_enabled=yaml_config.get("anonymous_enabled", False), quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"), + supports_image_input=supports_image_input, ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/__init__.py b/surfsense_backend/app/agents/new_chat/middleware/__init__.py index 094c102f8..6742bd8de 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/__init__.py +++ b/surfsense_backend/app/agents/new_chat/middleware/__init__.py @@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import ( from app.agents.new_chat.middleware.filesystem import ( SurfSenseFilesystemMiddleware, ) +from app.agents.new_chat.middleware.flatten_system import ( + FlattenSystemMessageMiddleware, +) from app.agents.new_chat.middleware.kb_persistence import ( KnowledgeBasePersistenceMiddleware, commit_staged_filesystem_state, @@ -61,6 +64,7 @@ __all__ = [ "DedupHITLToolCallsMiddleware", "DoomLoopMiddleware", "FileIntentMiddleware", + "FlattenSystemMessageMiddleware", "KnowledgeBasePersistenceMiddleware", "KnowledgeBaseSearchMiddleware", "KnowledgePriorityMiddleware", diff --git a/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py b/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py new file mode 100644 index 000000000..29cd57aa0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py @@ -0,0 +1,233 @@ +r"""Coalesce multi-block system messages into a single text block. + +Several middlewares in our deepagent stack each call +``append_to_system_message`` on the way down to the model +(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``, +``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the +request reaches the LLM, the system message has 5+ separate text blocks. + +Anthropic enforces a hard cap of **4 ``cache_control`` blocks per +request**, and we configure 2 injection points +(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting +the prepended ``request.system_message``, this middleware is the +defensive partner: it guarantees that "the system block" is *one* +content block, so LiteLLM's ``AnthropicCacheControlHook`` and any +OpenRouter→Anthropic transformer can never multiply our budget into +several breakpoints by spreading ``cache_control`` across multiple +text blocks of a multi-block system content. + +Without flattening we used to see:: + + OpenrouterException - {"error":{"message":"Provider returned error", + "code":400,"metadata":{"raw":"...A maximum of 4 blocks with + cache_control may be provided. Found 5."}}} + +(Same error class documented in +https://github.com/BerriAI/litellm/issues/15696 and +https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix +in PR #15395 covers the litellm transformer but does not protect us +when the OpenRouter SaaS itself does the redistribution.) + +A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching +the first injection point from ``role: system`` to ``index: 0``) +neutralises the *primary* cause of the same 400 — multiple +``SystemMessage``\ s injected by ``before_agent`` middlewares +(priority/tree/memory/file-intent/anonymous-doc) accumulating across +turns, each tagged with ``cache_control`` by the ``role: system`` +matcher. This middleware remains useful as defence-in-depth against +the multi-block redistribution path. + +Placement: innermost on the system-message-mutation chain, after every +appender (``todo``/``filesystem``/``skills``/``subagents``) and after +summarization, but before ``noop``/``retry``/``fallback`` so each retry +attempt sees a flattened payload. See ``chat_deepagent.py``. + +Idempotent: a string-content system message is left untouched. A list +that contains anything other than plain text blocks (e.g. an image) is +also left untouched — those are rare on system messages and we'd lose +the non-text payload by joining. +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, + ResponseT, +) +from langchain_core.messages import SystemMessage + +logger = logging.getLogger(__name__) + + +def _flatten_text_blocks(content: list[Any]) -> str | None: + """Return joined text if every block is a plain ``{"type": "text"}``. + + Returns ``None`` when the list contains anything that isn't a text + block we can safely concatenate (image, audio, file, non-standard + blocks, dicts with extra non-cache_control fields). The caller + leaves the original content untouched in that case rather than + silently dropping payload. + + ``cache_control`` on individual blocks is intentionally discarded — + the whole point of flattening is to let LiteLLM's + ``cache_control_injection_points`` re-place a single breakpoint on + the resulting one-block system content. + """ + chunks: list[str] = [] + for block in content: + if isinstance(block, str): + chunks.append(block) + continue + if not isinstance(block, dict): + return None + if block.get("type") != "text": + return None + text = block.get("text") + if not isinstance(text, str): + return None + chunks.append(text) + return "\n\n".join(chunks) + + +def _flattened_request( + request: ModelRequest[ContextT], +) -> ModelRequest[ContextT] | None: + """Return a request with system_message flattened, or ``None`` for no-op.""" + sys_msg = request.system_message + if sys_msg is None: + return None + content = sys_msg.content + if not isinstance(content, list) or len(content) <= 1: + return None + + flattened = _flatten_text_blocks(content) + if flattened is None: + return None + + new_sys = SystemMessage( + content=flattened, + additional_kwargs=dict(sys_msg.additional_kwargs), + response_metadata=dict(sys_msg.response_metadata), + ) + if sys_msg.id is not None: + new_sys.id = sys_msg.id + return request.override(system_message=new_sys) + + +def _diagnostic_summary(request: ModelRequest[Any]) -> str: + """One-line dump of cache_control-relevant request shape. + + Temporary diagnostic to prove where the ``Found N`` cache_control + breakpoints are coming from when Anthropic 400s. Removed once the + root cause is confirmed and a fix is in place. + """ + sys_msg = request.system_message + if sys_msg is None: + sys_shape = "none" + elif isinstance(sys_msg.content, str): + sys_shape = f"str(len={len(sys_msg.content)})" + elif isinstance(sys_msg.content, list): + sys_shape = f"list(blocks={len(sys_msg.content)})" + else: + sys_shape = f"other({type(sys_msg.content).__name__})" + + role_hist: list[str] = [] + multi_block_msgs = 0 + msgs_with_cc = 0 + sys_msgs_in_history = 0 + for m in request.messages: + mtype = getattr(m, "type", type(m).__name__) + role_hist.append(mtype) + if isinstance(m, SystemMessage): + sys_msgs_in_history += 1 + c = getattr(m, "content", None) + if isinstance(c, list): + multi_block_msgs += 1 + for blk in c: + if isinstance(blk, dict) and "cache_control" in blk: + msgs_with_cc += 1 + break + if "cache_control" in getattr(m, "additional_kwargs", {}) or {}: + msgs_with_cc += 1 + + tools = request.tools or [] + tools_with_cc = 0 + for t in tools: + if isinstance(t, dict) and ( + "cache_control" in t or "cache_control" in t.get("function", {}) + ): + tools_with_cc += 1 + + return ( + f"sys={sys_shape} msgs={len(request.messages)} " + f"sys_msgs_in_history={sys_msgs_in_history} " + f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} " + f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} " + f"roles={role_hist[-8:]}" + ) + + +class FlattenSystemMessageMiddleware( + AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT] +): + """Collapse a multi-text-block system message to a single string. + + Sits innermost on the system-message-mutation chain so it observes + every middleware's contribution. Has no other side effect — the + body of every block is preserved, just joined with ``"\\n\\n"``. + """ + + def __init__(self) -> None: + super().__init__() + self.tools = [] + + def wrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> Any: + if logger.isEnabledFor(logging.DEBUG): + logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request)) + flattened = _flattened_request(request) + if flattened is not None: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[flatten_system] collapsed %d system blocks to one", + len(request.system_message.content), # type: ignore[arg-type, union-attr] + ) + return handler(flattened) + return handler(request) + + async def awrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[ + [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] + ], + ) -> Any: + if logger.isEnabledFor(logging.DEBUG): + logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request)) + flattened = _flattened_request(request) + if flattened is not None: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[flatten_system] collapsed %d system blocks to one", + len(request.system_message.content), # type: ignore[arg-type, union-attr] + ) + return await handler(flattened) + return await handler(request) + + +__all__ = [ + "FlattenSystemMessageMiddleware", + "_flatten_text_blocks", + "_flattened_request", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index 0820e8c3e..ee5c1d182 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -732,7 +732,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] state: AgentState, runtime: Runtime[Any], ) -> dict[str, Any] | None: - del runtime if self.filesystem_mode != FilesystemMode.CLOUD: return None @@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] if anon_doc: return self._anon_priority(state, anon_doc) - return await self._authenticated_priority(state, messages, user_text) + return await self._authenticated_priority(state, messages, user_text, runtime) def _anon_priority( self, @@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] state: AgentState, messages: Sequence[BaseMessage], user_text: str, + runtime: Runtime[Any] | None = None, ) -> dict[str, Any]: t0 = asyncio.get_event_loop().time() ( @@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] user_text=user_text, ) + # Per-turn ``mentioned_document_ids`` flow: + # 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the + # streaming task supplies a fresh :class:`SurfSenseContextSchema` + # on every ``astream_events`` call, so this list is naturally + # scoped to the current turn. Allows cross-turn graph reuse via + # ``agent_cache``. + # 2. Legacy fallback (cache disabled / context not propagated): the + # constructor-injected ``self.mentioned_document_ids`` list. We + # drain it after the first read so a cached graph (no Phase 1.5 + # wiring) doesn't keep replaying the same mentions on every + # turn. + # + # CRITICAL: distinguish "context absent" (legacy caller, no field at + # all) from "context provided but empty" (turn with no mentions). + # ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in + # Python, so a naive ``if ctx_mentions:`` would fall through to the + # legacy closure on every no-mention follow-up turn — replaying the + # mentions baked in by turn 1's cache-miss build. Always drain the + # closure once the runtime path has fired so a cached middleware + # instance can never resurrect stale state. + mention_ids: list[int] = [] + ctx = getattr(runtime, "context", None) if runtime is not None else None + ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None + if ctx_mentions is not None: + # Runtime path is authoritative — even an empty list means + # "this turn has no mentions", NOT "look at the closure". + mention_ids = list(ctx_mentions) + if self.mentioned_document_ids: + self.mentioned_document_ids = [] + elif self.mentioned_document_ids: + mention_ids = list(self.mentioned_document_ids) + self.mentioned_document_ids = [] + mentioned_results: list[dict[str, Any]] = [] - if self.mentioned_document_ids: + if mention_ids: mentioned_results = await fetch_mentioned_documents( - document_ids=self.mentioned_document_ids, + document_ids=mention_ids, search_space_id=self.search_space_id, ) - self.mentioned_document_ids = [] if is_recency: doc_types = _resolve_search_types( diff --git a/surfsense_backend/app/agents/new_chat/prompt_caching.py b/surfsense_backend/app/agents/new_chat/prompt_caching.py index 86bc57725..9fe47cdac 100644 --- a/surfsense_backend/app/agents/new_chat/prompt_caching.py +++ b/surfsense_backend/app/agents/new_chat/prompt_caching.py @@ -1,4 +1,4 @@ -"""LiteLLM-native prompt caching configuration for SurfSense agents. +r"""LiteLLM-native prompt caching configuration for SurfSense agents. Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)`` @@ -17,8 +17,20 @@ Coverage: We inject **two** breakpoints per request: -- ``role: system`` — pins the SurfSense system prompt (provider variant, - citation rules, tool catalog, KB tree, skills metadata) into the cache. +- ``index: 0`` — pins the SurfSense system prompt at the head of the + request (provider variant, citation rules, tool catalog, KB tree, + skills metadata). The langchain agent factory always prepends + ``request.system_message`` at index 0 (see ``factory.py`` + ``_execute_model_async``), so this targets exactly the main system + prompt regardless of how many other ``SystemMessage``\ s the + ``before_agent`` injectors (priority, tree, memory, file-intent, + anonymous-doc) have inserted into ``state["messages"]``. Using + ``role: system`` here would apply ``cache_control`` to **every** + system-role message and trip Anthropic's hard cap of 4 cache + breakpoints per request once the conversation accumulates enough + injected system messages — which surfaces as the upstream 400 + ``A maximum of 4 blocks with cache_control may be provided. Found N`` + via OpenRouter→Anthropic. - ``index: -1`` — pins the latest message so multi-turn savings compound: Anthropic-family providers use longest-matching-prefix lookup, so turn N+1 still reads turn N's cache up to the shared prefix. @@ -51,11 +63,21 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Two-breakpoint policy: system + latest message. See module docstring for -# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we -# use 2 here, leaving headroom for Phase-2 tool caching. +# Two-breakpoint policy: head-of-request + latest message. See module +# docstring for rationale. Anthropic caps requests at 4 ``cache_control`` +# blocks; we use 2 here, leaving headroom for Phase-2 tool caching. +# +# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's +# ``before_agent`` middlewares (priority, tree, memory, file-intent, +# anonymous-doc) insert ``SystemMessage`` instances into +# ``state["messages"]`` that accumulate across turns. With +# ``role: system`` the LiteLLM hook would tag *every* one of them with +# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0`` +# always targets the langchain-prepended ``request.system_message`` +# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text +# block), giving us exactly one stable cache breakpoint. _DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = ( - {"location": "message", "role": "system"}, + {"location": "message", "index": 0}, {"location": "message", "index": -1}, ) diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py index 095413bdb..c56db1528 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py @@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector +from app.db import async_session_maker from app.services.confluence import ConfluenceToolMetadataService logger = logging.getLogger(__name__) @@ -18,6 +19,23 @@ def create_create_confluence_page_tool( user_id: str | None = None, connector_id: int | None = None, ): + """ + Factory function to create the create_confluence_page tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_confluence_page tool + """ + del db_session # per-call session — see docstring + @tool async def create_confluence_page( title: str, @@ -42,160 +60,163 @@ def create_create_confluence_page_tool( """ logger.info(f"create_confluence_page called: title='{title}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Confluence tool not properly configured.", } try: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) + async with async_session_maker() as db_session: + metadata_service = ConfluenceToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id + ) - if "error" in context: - return {"status": "error", "message": context["error"]} + if "error" in context: + return {"status": "error", "message": context["error"]} - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected Confluence accounts need re-authentication.", - "connector_type": "confluence", - } + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + return { + "status": "auth_error", + "message": "All connected Confluence accounts need re-authentication.", + "connector_type": "confluence", + } - result = request_approval( - action_type="confluence_page_creation", - tool_name="create_confluence_page", - params={ - "title": title, - "content": content, - "space_id": space_id, - "connector_id": connector_id, - }, - context=context, - ) + result = request_approval( + action_type="confluence_page_creation", + tool_name="create_confluence_page", + params={ + "title": title, + "content": content, + "space_id": space_id, + "connector_id": connector_id, + }, + context=context, + ) - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } - final_title = result.params.get("title", title) - final_content = result.params.get("content", content) or "" - final_space_id = result.params.get("space_id", space_id) - final_connector_id = result.params.get("connector_id", connector_id) + final_title = result.params.get("title", title) + final_content = result.params.get("content", content) or "" + final_space_id = result.params.get("space_id", space_id) + final_connector_id = result.params.get("connector_id", connector_id) - if not final_title or not final_title.strip(): - return {"status": "error", "message": "Page title cannot be empty."} - if not final_space_id: - return {"status": "error", "message": "A space must be selected."} + if not final_title or not final_title.strip(): + return {"status": "error", "message": "Page title cannot be empty."} + if not final_space_id: + return {"status": "error", "message": "A space must be selected."} - from sqlalchemy.future import select + from sqlalchemy.future import select - from app.db import SearchSourceConnector, SearchSourceConnectorType + from app.db import SearchSourceConnector, SearchSourceConnectorType - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Confluence connector found.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } - - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=actual_connector_id - ) - api_result = await client.create_page( - space_id=final_space_id, - title=final_title, - body=final_content, - ) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in str(api_err).lower() - ): - try: - _conn = connector - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - pass - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", - } - raise - - page_id = str(api_result.get("id", "")) - page_links = ( - api_result.get("_links", {}) if isinstance(api_result, dict) else {} - ) - page_url = "" - if page_links.get("base") and page_links.get("webui"): - page_url = f"{page_links['base']}{page_links['webui']}" - - kb_message_suffix = "" - try: - from app.services.confluence import ConfluenceKBSyncService - - kb_service = ConfluenceKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - page_id=page_id, - page_title=final_title, - space_id=final_space_id, - body_content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Confluence connector found.", + } + actual_connector_id = connector.id else: - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Confluence connector is invalid.", + } - return { - "status": "success", - "page_id": page_id, - "page_url": page_url, - "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}", - } + try: + client = ConfluenceHistoryConnector( + session=db_session, connector_id=actual_connector_id + ) + api_result = await client.create_page( + space_id=final_space_id, + title=final_title, + body=final_content, + ) + await client.close() + except Exception as api_err: + if ( + "http 403" in str(api_err).lower() + or "status code 403" in str(api_err).lower() + ): + try: + _conn = connector + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + page_id = str(api_result.get("id", "")) + page_links = ( + api_result.get("_links", {}) if isinstance(api_result, dict) else {} + ) + page_url = "" + if page_links.get("base") and page_links.get("webui"): + page_url = f"{page_links['base']}{page_links['webui']}" + + kb_message_suffix = "" + try: + from app.services.confluence import ConfluenceKBSyncService + + kb_service = ConfluenceKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + page_id=page_id, + page_title=final_title, + space_id=final_space_id, + body_content=final_content, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "page_id": page_id, + "page_url": page_url, + "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py index 7c03c2760..d4cd5032f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py @@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector +from app.db import async_session_maker from app.services.confluence import ConfluenceToolMetadataService logger = logging.getLogger(__name__) @@ -18,6 +19,23 @@ def create_delete_confluence_page_tool( user_id: str | None = None, connector_id: int | None = None, ): + """ + Factory function to create the delete_confluence_page tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured delete_confluence_page tool + """ + del db_session # per-call session — see docstring + @tool async def delete_confluence_page( page_title_or_id: str, @@ -43,137 +61,143 @@ def create_delete_confluence_page_tool( f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Confluence tool not properly configured.", } try: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_deletion_context( - search_space_id, user_id, page_title_or_id - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "confluence", - } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - - page_data = context["page"] - page_id = page_data["page_id"] - page_title = page_data.get("page_title", "") - document_id = page_data["document_id"] - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="confluence_page_deletion", - tool_name="delete_confluence_page", - params={ - "page_id": page_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this page.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = ConfluenceToolMetadataService(db_session) + context = await metadata_service.get_deletion_context( + search_space_id, user_id, page_title_or_id ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=final_connector_id + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "confluence", + } + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} + + page_data = context["page"] + page_id = page_data["page_id"] + page_title = page_data.get("page_title", "") + document_id = page_data["document_id"] + connector_id_from_context = context.get("account", {}).get("id") + + result = request_approval( + action_type="confluence_page_deletion", + tool_name="delete_confluence_page", + params={ + "page_id": page_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - await client.delete_page(final_page_id) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in str(api_err).lower() - ): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", } - raise - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document + final_page_id = result.params.get("page_id", page_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this page.", + } + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Confluence connector is invalid.", + } - message = f"Confluence page '{page_title}' deleted successfully." - if deleted_from_kb: - message += " Also removed from the knowledge base." + try: + client = ConfluenceHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + await client.delete_page(final_page_id) + await client.close() + except Exception as api_err: + if ( + "http 403" in str(api_err).lower() + or "status code 403" in str(api_err).lower() + ): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + } + raise - return { - "status": "success", - "page_id": final_page_id, - "deleted_from_kb": deleted_from_kb, - "message": message, - } + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + + message = f"Confluence page '{page_title}' deleted successfully." + if deleted_from_kb: + message += " Also removed from the knowledge base." + + return { + "status": "success", + "page_id": final_page_id, + "deleted_from_kb": deleted_from_kb, + "message": message, + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py index 791d0d8c5..51c205e00 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py @@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector +from app.db import async_session_maker from app.services.confluence import ConfluenceToolMetadataService logger = logging.getLogger(__name__) @@ -18,6 +19,23 @@ def create_update_confluence_page_tool( user_id: str | None = None, connector_id: int | None = None, ): + """ + Factory function to create the update_confluence_page tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured update_confluence_page tool + """ + del db_session # per-call session — see docstring + @tool async def update_confluence_page( page_title_or_id: str, @@ -45,164 +63,168 @@ def create_update_confluence_page_tool( f"update_confluence_page called: page_title_or_id='{page_title_or_id}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Confluence tool not properly configured.", } try: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, page_title_or_id - ) + async with async_session_maker() as db_session: + metadata_service = ConfluenceToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, page_title_or_id + ) - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "confluence", + } + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} + + page_data = context["page"] + page_id = page_data["page_id"] + current_title = page_data["page_title"] + current_body = page_data.get("body", "") + current_version = page_data.get("version", 1) + document_id = page_data.get("document_id") + connector_id_from_context = context.get("account", {}).get("id") + + result = request_approval( + action_type="confluence_page_update", + tool_name="update_confluence_page", + params={ + "page_id": page_id, + "document_id": document_id, + "new_title": new_title, + "new_content": new_content, + "version": current_version, + "connector_id": connector_id_from_context, + }, + context=context, + ) + + if result.rejected: return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "confluence", + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - page_data = context["page"] - page_id = page_data["page_id"] - current_title = page_data["page_title"] - current_body = page_data.get("body", "") - current_version = page_data.get("version", 1) - document_id = page_data.get("document_id") - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="confluence_page_update", - tool_name="update_confluence_page", - params={ - "page_id": page_id, - "document_id": document_id, - "new_title": new_title, - "new_content": new_content, - "version": current_version, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_title = result.params.get("new_title", new_title) or current_title - final_content = result.params.get("new_content", new_content) - if final_content is None: - final_content = current_body - final_version = result.params.get("version", current_version) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_document_id = result.params.get("document_id", document_id) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this page.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + final_page_id = result.params.get("page_id", page_id) + final_title = result.params.get("new_title", new_title) or current_title + final_content = result.params.get("new_content", new_content) + if final_content is None: + final_content = current_body + final_version = result.params.get("version", current_version) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } + final_document_id = result.params.get("document_id", document_id) - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=final_connector_id - ) - api_result = await client.update_page( - page_id=final_page_id, - title=final_title, - body=final_content, - version_number=final_version + 1, - ) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in str(api_err).lower() - ): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + "status": "error", + "message": "No connector found for this page.", } - raise - page_links = ( - api_result.get("_links", {}) if isinstance(api_result, dict) else {} - ) - page_url = "" - if page_links.get("base") and page_links.get("webui"): - page_url = f"{page_links['base']}{page_links['webui']}" - - kb_message_suffix = "" - if final_document_id: - try: - from app.services.confluence import ConfluenceKBSyncService - - kb_service = ConfluenceKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=final_document_id, - page_id=final_page_id, - user_id=user_id, - search_space_id=search_space_id, + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Confluence connector is invalid.", + } + + try: + client = ConfluenceHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + api_result = await client.update_page( + page_id=final_page_id, + title=final_title, + body=final_content, + version_number=final_version + 1, + ) + await client.close() + except Exception as api_err: + if ( + "http 403" in str(api_err).lower() + or "status code 403" in str(api_err).lower() + ): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + page_links = ( + api_result.get("_links", {}) if isinstance(api_result, dict) else {} + ) + page_url = "" + if page_links.get("base") and page_links.get("webui"): + page_url = f"{page_links['base']}{page_links['webui']}" + + kb_message_suffix = "" + if final_document_id: + try: + from app.services.confluence import ConfluenceKBSyncService + + kb_service = ConfluenceKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=final_document_id, + page_id=final_page_id, + user_id=user_id, + search_space_id=search_space_id, ) - else: + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = ( + " The knowledge base will be updated in the next sync." + ) + except Exception as kb_err: + logger.warning(f"KB sync after update failed: {kb_err}") kb_message_suffix = ( " The knowledge base will be updated in the next sync." ) - except Exception as kb_err: - logger.warning(f"KB sync after update failed: {kb_err}") - kb_message_suffix = ( - " The knowledge base will be updated in the next sync." - ) - return { - "status": "success", - "page_id": final_page_id, - "page_url": page_url, - "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}", - } + return { + "status": "success", + "page_id": final_page_id, + "page_url": page_url, + "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py index 5675a42e6..6420a90e6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py +++ b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker from app.services.mcp_oauth.registry import MCP_SERVICES logger = logging.getLogger(__name__) @@ -53,6 +53,23 @@ def create_get_connected_accounts_tool( search_space_id: int, user_id: str, ) -> StructuredTool: + """Factory function to create the get_connected_accounts tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to scope account discovery to. + user_id: User ID to scope account discovery to. + + Returns: + Configured StructuredTool for connected-accounts discovery. + """ + del db_session # per-call session — see docstring async def _run(service: str) -> list[dict[str, Any]]: svc_cfg = MCP_SERVICES.get(service) @@ -68,40 +85,41 @@ def create_get_connected_accounts_tool( except ValueError: return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}] - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == connector_type, + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type == connector_type, + ) ) - ) - connectors = result.scalars().all() + connectors = result.scalars().all() - if not connectors: - return [ - { - "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings." + if not connectors: + return [ + { + "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings." + } + ] + + is_multi = len(connectors) > 1 + + accounts: list[dict[str, Any]] = [] + for conn in connectors: + cfg = conn.config or {} + entry: dict[str, Any] = { + "connector_id": conn.id, + "display_name": _extract_display_name(conn), + "service": service, } - ] + if is_multi: + entry["tool_prefix"] = f"{service}_{conn.id}" + for key in svc_cfg.account_metadata_keys: + if key in cfg: + entry[key] = cfg[key] + accounts.append(entry) - is_multi = len(connectors) > 1 - - accounts: list[dict[str, Any]] = [] - for conn in connectors: - cfg = conn.config or {} - entry: dict[str, Any] = { - "connector_id": conn.id, - "display_name": _extract_display_name(conn), - "service": service, - } - if is_multi: - entry["tool_prefix"] = f"{service}_{conn.id}" - for key in svc_cfg.account_metadata_keys: - if key in cfg: - entry[key] = cfg[key] - accounts.append(entry) - - return accounts + return accounts return StructuredTool( name="get_connected_accounts", diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py index 3cc99ac17..01159a261 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_list_discord_channels_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the list_discord_channels tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured list_discord_channels tool + """ + del db_session # per-call session — see docstring + @tool async def list_discord_channels() -> dict[str, Any]: """List text channels in the connected Discord server. @@ -22,59 +41,60 @@ def create_list_discord_channels_tool( Returns: Dictionary with status and a list of channels (id, name). """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Discord tool not properly configured.", } try: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} - - guild_id = get_guild_id(connector) - if not guild_id: - return { - "status": "error", - "message": "No guild ID in Discord connector config.", - } - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{DISCORD_API}/guilds/{guild_id}/channels", - headers={"Authorization": f"Bot {token}"}, - timeout=15.0, + async with async_session_maker() as db_session: + connector = await get_discord_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Discord connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } + guild_id = get_guild_id(connector) + if not guild_id: + return { + "status": "error", + "message": "No guild ID in Discord connector config.", + } - # Type 0 = text channel - channels = [ - {"id": ch["id"], "name": ch["name"]} - for ch in resp.json() - if ch.get("type") == 0 - ] - return { - "status": "success", - "guild_id": guild_id, - "channels": channels, - "total": len(channels), - } + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{DISCORD_API}/guilds/{guild_id}/channels", + headers={"Authorization": f"Bot {token}"}, + timeout=15.0, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } + + # Type 0 = text channel + channels = [ + {"id": ch["id"], "name": ch["name"]} + for ch in resp.json() + if ch.get("type") == 0 + ] + return { + "status": "success", + "guild_id": guild_id, + "channels": channels, + "total": len(channels), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py index d8bf989a1..88d6cdd49 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import DISCORD_API, get_bot_token, get_discord_connector logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_read_discord_messages_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_discord_messages tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_discord_messages tool + """ + del db_session # per-call session — see docstring + @tool async def read_discord_messages( channel_id: str, @@ -30,7 +49,7 @@ def create_read_discord_messages_tool( Dictionary with status and a list of messages including id, author, content, timestamp. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Discord tool not properly configured.", @@ -39,55 +58,56 @@ def create_read_discord_messages_tool( limit = min(limit, 50) try: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{DISCORD_API}/channels/{channel_id}/messages", - headers={"Authorization": f"Bot {token}"}, - params={"limit": limit}, - timeout=15.0, + async with async_session_maker() as db_session: + connector = await get_discord_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Discord connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Bot lacks permission to read this channel.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } + token = get_bot_token(connector) - messages = [ - { - "id": m["id"], - "author": m.get("author", {}).get("username", "Unknown"), - "content": m.get("content", ""), - "timestamp": m.get("timestamp", ""), - } - for m in resp.json() - ] + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{DISCORD_API}/channels/{channel_id}/messages", + headers={"Authorization": f"Bot {token}"}, + params={"limit": limit}, + timeout=15.0, + ) - return { - "status": "success", - "channel_id": channel_id, - "messages": messages, - "total": len(messages), - } + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Bot lacks permission to read this channel.", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } + + messages = [ + { + "id": m["id"], + "author": m.get("author", {}).get("username", "Unknown"), + "content": m.get("content", ""), + "timestamp": m.get("timestamp", ""), + } + for m in resp.json() + ] + + return { + "status": "success", + "channel_id": channel_id, + "messages": messages, + "total": len(messages), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py index 236cd017a..5fe6fde35 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py @@ -6,6 +6,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from ._auth import DISCORD_API, get_bot_token, get_discord_connector @@ -17,6 +18,23 @@ def create_send_discord_message_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the send_discord_message tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured send_discord_message tool + """ + del db_session # per-call session — see docstring + @tool async def send_discord_message( channel_id: str, @@ -34,7 +52,7 @@ def create_send_discord_message_tool( IMPORTANT: - If status is "rejected", the user explicitly declined. Do NOT retry. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Discord tool not properly configured.", @@ -47,64 +65,65 @@ def create_send_discord_message_tool( } try: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} + async with async_session_maker() as db_session: + connector = await get_discord_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Discord connector found."} - result = request_approval( - action_type="discord_send_message", - tool_name="send_discord_message", - params={"channel_id": channel_id, "content": content}, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Message was not sent.", - } - - final_content = result.params.get("content", content) - final_channel = result.params.get("channel_id", channel_id) - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.post( - f"{DISCORD_API}/channels/{final_channel}/messages", - headers={ - "Authorization": f"Bot {token}", - "Content-Type": "application/json", - }, - json={"content": final_content}, - timeout=15.0, + result = request_approval( + action_type="discord_send_message", + tool_name="send_discord_message", + params={"channel_id": channel_id, "content": content}, + context={"connector_id": connector.id}, ) - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Bot lacks permission to send messages in this channel.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Message was not sent.", + } - msg_data = resp.json() - return { - "status": "success", - "message_id": msg_data.get("id"), - "message": f"Message sent to channel {final_channel}.", - } + final_content = result.params.get("content", content) + final_channel = result.params.get("channel_id", channel_id) + + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{DISCORD_API}/channels/{final_channel}/messages", + headers={ + "Authorization": f"Bot {token}", + "Content-Type": "application/json", + }, + json={"content": final_content}, + timeout=15.0, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Bot lacks permission to send messages in this channel.", + } + if resp.status_code not in (200, 201): + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } + + msg_data = resp.json() + return { + "status": "success", + "message_id": msg_data.get("id"), + "message": f"Message sent to channel {final_channel}.", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py index 22d8a8a27..7aae034cc 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py @@ -10,7 +10,7 @@ from sqlalchemy.future import select from app.agents.new_chat.tools.hitl import request_approval from app.connectors.dropbox.client import DropboxClient -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -59,6 +59,23 @@ def create_create_dropbox_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_dropbox_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_dropbox_file tool + """ + del db_session # per-call session — see docstring + @tool async def create_dropbox_file( name: str, @@ -82,184 +99,191 @@ def create_create_dropbox_file_tool( f"create_dropbox_file called: name='{name}', file_type='{file_type}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Dropbox tool not properly configured.", } try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.DROPBOX_CONNECTOR, - ) - ) - connectors = result.scalars().all() - - if not connectors: - return { - "status": "error", - "message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.", - } - - accounts = [] - for c in connectors: - cfg = c.config or {} - accounts.append( - { - "id": c.id, - "name": c.name, - "user_email": cfg.get("user_email"), - "auth_expired": cfg.get("auth_expired", False), - } - ) - - if all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected Dropbox accounts need re-authentication.", - "connector_type": "dropbox", - } - - parent_folders: dict[int, list[dict[str, str]]] = {} - for acc in accounts: - cid = acc["id"] - if acc.get("auth_expired"): - parent_folders[cid] = [] - continue - try: - client = DropboxClient(session=db_session, connector_id=cid) - items, err = await client.list_folder("") - if err: - logger.warning( - "Failed to list folders for connector %s: %s", cid, err - ) - parent_folders[cid] = [] - else: - parent_folders[cid] = [ - { - "folder_path": item.get("path_lower", ""), - "name": item["name"], - } - for item in items - if item.get(".tag") == "folder" and item.get("name") - ] - except Exception: - logger.warning( - "Error fetching folders for connector %s", cid, exc_info=True - ) - parent_folders[cid] = [] - - context: dict[str, Any] = { - "accounts": accounts, - "parent_folders": parent_folders, - "supported_types": _SUPPORTED_TYPES, - } - - result = request_approval( - action_type="dropbox_file_creation", - tool_name="create_dropbox_file", - params={ - "name": name, - "file_type": file_type, - "content": content, - "connector_id": None, - "parent_folder_path": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_file_type = result.params.get("file_type", file_type) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_path = result.params.get("parent_folder_path") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - final_name = _ensure_extension(final_name, final_file_type) - - if final_connector_id is not None: + async with async_session_maker() as db_session: result = await db_session.execute( select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type == SearchSourceConnectorType.DROPBOX_CONNECTOR, ) ) - connector = result.scalars().first() - else: - connector = connectors[0] + connectors = result.scalars().all() - if not connector: - return { - "status": "error", - "message": "Selected Dropbox connector is invalid.", + if not connectors: + return { + "status": "error", + "message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.", + } + + accounts = [] + for c in connectors: + cfg = c.config or {} + accounts.append( + { + "id": c.id, + "name": c.name, + "user_email": cfg.get("user_email"), + "auth_expired": cfg.get("auth_expired", False), + } + ) + + if all(a.get("auth_expired") for a in accounts): + return { + "status": "auth_error", + "message": "All connected Dropbox accounts need re-authentication.", + "connector_type": "dropbox", + } + + parent_folders: dict[int, list[dict[str, str]]] = {} + for acc in accounts: + cid = acc["id"] + if acc.get("auth_expired"): + parent_folders[cid] = [] + continue + try: + client = DropboxClient(session=db_session, connector_id=cid) + items, err = await client.list_folder("") + if err: + logger.warning( + "Failed to list folders for connector %s: %s", cid, err + ) + parent_folders[cid] = [] + else: + parent_folders[cid] = [ + { + "folder_path": item.get("path_lower", ""), + "name": item["name"], + } + for item in items + if item.get(".tag") == "folder" and item.get("name") + ] + except Exception: + logger.warning( + "Error fetching folders for connector %s", + cid, + exc_info=True, + ) + parent_folders[cid] = [] + + context: dict[str, Any] = { + "accounts": accounts, + "parent_folders": parent_folders, + "supported_types": _SUPPORTED_TYPES, } - client = DropboxClient(session=db_session, connector_id=connector.id) - - parent_path = final_parent_folder_path or "" - file_path = ( - f"{parent_path}/{final_name}" if parent_path else f"/{final_name}" - ) - - if final_file_type == "paper": - created = await client.create_paper_doc(file_path, final_content or "") - file_id = created.get("file_id", "") - web_url = created.get("url", "") - else: - docx_bytes = _markdown_to_docx(final_content or "") - created = await client.upload_file( - file_path, docx_bytes, mode="add", autorename=True + result = request_approval( + action_type="dropbox_file_creation", + tool_name="create_dropbox_file", + params={ + "name": name, + "file_type": file_type, + "content": content, + "connector_id": None, + "parent_folder_path": None, + }, + context=context, ) - file_id = created.get("id", "") - web_url = "" - logger.info(f"Dropbox file created: id={file_id}, name={final_name}") + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } - kb_message_suffix = "" - try: - from app.services.dropbox import DropboxKBSyncService + final_name = result.params.get("name", name) + final_file_type = result.params.get("file_type", file_type) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_path = result.params.get("parent_folder_path") - kb_service = DropboxKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=file_id, - file_name=final_name, - file_path=file_path, - web_url=web_url, - content=final_content, - connector_id=connector.id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." + if not final_name or not final_name.strip(): + return {"status": "error", "message": "File name cannot be empty."} + + final_name = _ensure_extension(final_name, final_file_type) + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.DROPBOX_CONNECTOR, + ) + ) + connector = result.scalars().first() else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + connector = connectors[0] - return { - "status": "success", - "file_id": file_id, - "name": final_name, - "web_url": web_url, - "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}", - } + if not connector: + return { + "status": "error", + "message": "Selected Dropbox connector is invalid.", + } + + client = DropboxClient(session=db_session, connector_id=connector.id) + + parent_path = final_parent_folder_path or "" + file_path = ( + f"{parent_path}/{final_name}" if parent_path else f"/{final_name}" + ) + + if final_file_type == "paper": + created = await client.create_paper_doc( + file_path, final_content or "" + ) + file_id = created.get("file_id", "") + web_url = created.get("url", "") + else: + docx_bytes = _markdown_to_docx(final_content or "") + created = await client.upload_file( + file_path, docx_bytes, mode="add", autorename=True + ) + file_id = created.get("id", "") + web_url = "" + + logger.info(f"Dropbox file created: id={file_id}, name={final_name}") + + kb_message_suffix = "" + try: + from app.services.dropbox import DropboxKBSyncService + + kb_service = DropboxKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + file_id=file_id, + file_name=final_name, + file_path=file_path, + web_url=web_url, + content=final_content, + connector_id=connector.id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "file_id": file_id, + "name": final_name, + "web_url": web_url, + "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py index 12559b57a..0e59e49db 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py @@ -13,6 +13,7 @@ from app.db import ( DocumentType, SearchSourceConnector, SearchSourceConnectorType, + async_session_maker, ) logger = logging.getLogger(__name__) @@ -23,6 +24,23 @@ def create_delete_dropbox_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the delete_dropbox_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured delete_dropbox_file tool + """ + del db_session # per-call session — see docstring + @tool async def delete_dropbox_file( file_name: str, @@ -55,33 +73,14 @@ def create_delete_dropbox_file_tool( f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Dropbox tool not properly configured.", } try: - doc_result = await db_session.execute( - select(Document) - .join( - SearchSourceConnector, - Document.connector_id == SearchSourceConnector.id, - ) - .filter( - and_( - Document.search_space_id == search_space_id, - Document.document_type == DocumentType.DROPBOX_FILE, - func.lower(Document.title) == func.lower(file_name), - SearchSourceConnector.user_id == user_id, - ) - ) - .order_by(Document.updated_at.desc().nullslast()) - .limit(1) - ) - document = doc_result.scalars().first() - - if not document: + async with async_session_maker() as db_session: doc_result = await db_session.execute( select(Document) .join( @@ -92,13 +91,7 @@ def create_delete_dropbox_file_tool( and_( Document.search_space_id == search_space_id, Document.document_type == DocumentType.DROPBOX_FILE, - func.lower( - cast( - Document.document_metadata["dropbox_file_name"], - String, - ) - ) - == func.lower(file_name), + func.lower(Document.title) == func.lower(file_name), SearchSourceConnector.user_id == user_id, ) ) @@ -107,99 +100,63 @@ def create_delete_dropbox_file_tool( ) document = doc_result.scalars().first() - if not document: - return { - "status": "not_found", - "message": ( - f"File '{file_name}' not found in your indexed Dropbox files. " - "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " - "or (3) the file name is different." - ), - } - - if not document.connector_id: - return { - "status": "error", - "message": "Document has no associated connector.", - } - - meta = document.document_metadata or {} - file_path = meta.get("dropbox_path") - file_id = meta.get("dropbox_file_id") - document_id = document.id - - if not file_path: - return { - "status": "error", - "message": "File path is missing. Please re-index the file.", - } - - conn_result = await db_session.execute( - select(SearchSourceConnector).filter( - and_( - SearchSourceConnector.id == document.connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.DROPBOX_CONNECTOR, + if not document: + doc_result = await db_session.execute( + select(Document) + .join( + SearchSourceConnector, + Document.connector_id == SearchSourceConnector.id, + ) + .filter( + and_( + Document.search_space_id == search_space_id, + Document.document_type == DocumentType.DROPBOX_FILE, + func.lower( + cast( + Document.document_metadata["dropbox_file_name"], + String, + ) + ) + == func.lower(file_name), + SearchSourceConnector.user_id == user_id, + ) + ) + .order_by(Document.updated_at.desc().nullslast()) + .limit(1) ) - ) - ) - connector = conn_result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Dropbox connector not found or access denied.", - } + document = doc_result.scalars().first() - cfg = connector.config or {} - if cfg.get("auth_expired"): - return { - "status": "auth_error", - "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "dropbox", - } + if not document: + return { + "status": "not_found", + "message": ( + f"File '{file_name}' not found in your indexed Dropbox files. " + "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " + "or (3) the file name is different." + ), + } - context = { - "file": { - "file_id": file_id, - "file_path": file_path, - "name": file_name, - "document_id": document_id, - }, - "account": { - "id": connector.id, - "name": connector.name, - "user_email": cfg.get("user_email"), - }, - } + if not document.connector_id: + return { + "status": "error", + "message": "Document has no associated connector.", + } - result = request_approval( - action_type="dropbox_file_trash", - tool_name="delete_dropbox_file", - params={ - "file_path": file_path, - "connector_id": connector.id, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + meta = document.document_metadata or {} + file_path = meta.get("dropbox_path") + file_id = meta.get("dropbox_file_id") + document_id = document.id - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } + if not file_path: + return { + "status": "error", + "message": "File path is missing. Please re-index the file.", + } - final_file_path = result.params.get("file_path", file_path) - final_connector_id = result.params.get("connector_id", connector.id) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if final_connector_id != connector.id: - result = await db_session.execute( + conn_result = await db_session.execute( select(SearchSourceConnector).filter( and_( - SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.id == document.connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type @@ -207,61 +164,128 @@ def create_delete_dropbox_file_tool( ) ) ) - validated_connector = result.scalars().first() - if not validated_connector: + connector = conn_result.scalars().first() + if not connector: return { "status": "error", - "message": "Selected Dropbox connector is invalid or has been disconnected.", + "message": "Dropbox connector not found or access denied.", } - actual_connector_id = validated_connector.id - else: - actual_connector_id = connector.id - logger.info( - f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}" - ) + cfg = connector.config or {} + if cfg.get("auth_expired"): + return { + "status": "auth_error", + "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "dropbox", + } - client = DropboxClient(session=db_session, connector_id=actual_connector_id) - await client.delete_file(final_file_path) + context = { + "file": { + "file_id": file_id, + "file_path": file_path, + "name": file_name, + "document_id": document_id, + }, + "account": { + "id": connector.id, + "name": connector.name, + "user_email": cfg.get("user_email"), + }, + } - logger.info(f"Dropbox file deleted: path={final_file_path}") - - trash_result: dict[str, Any] = { - "status": "success", - "file_id": file_id, - "message": f"Successfully deleted '{file_name}' from Dropbox.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - doc = doc_result.scalars().first() - if doc: - await db_session.delete(doc) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"File deleted, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + result = request_approval( + action_type="dropbox_file_trash", + tool_name="delete_dropbox_file", + params={ + "file_path": file_path, + "connector_id": connector.id, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - return trash_result + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_file_path = result.params.get("file_path", file_path) + final_connector_id = result.params.get("connector_id", connector.id) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + if final_connector_id != connector.id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + and_( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id + == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.DROPBOX_CONNECTOR, + ) + ) + ) + validated_connector = result.scalars().first() + if not validated_connector: + return { + "status": "error", + "message": "Selected Dropbox connector is invalid or has been disconnected.", + } + actual_connector_id = validated_connector.id + else: + actual_connector_id = connector.id + + logger.info( + f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}" + ) + + client = DropboxClient( + session=db_session, connector_id=actual_connector_id + ) + await client.delete_file(final_file_path) + + logger.info(f"Dropbox file deleted: path={final_file_path}") + + trash_result: dict[str, Any] = { + "status": "success", + "file_id": file_id, + "message": f"Successfully deleted '{file_name}' from Dropbox.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + doc = doc_result.scalars().first() + if doc: + await db_session.delete(doc) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"File deleted, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py index 3803fa39c..9e287ac51 100644 --- a/surfsense_backend/app/agents/new_chat/tools/generate_image.py +++ b/surfsense_backend/app/agents/new_chat/tools/generate_image.py @@ -31,6 +31,7 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) @@ -49,12 +50,16 @@ _PROVIDER_MAP = { } +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + def _build_model_string( provider: str, model_name: str, custom_provider: str | None ) -> str: - if custom_provider: - return f"{custom_provider}/{model_name}" - prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) + prefix = _resolve_provider_prefix(provider, custom_provider) return f"{prefix}/{model_name}" @@ -146,14 +151,18 @@ def create_generate_image_tool( "error": f"Image generation config {config_id} not found" } - model_string = _build_model_string( - cfg.get("provider", ""), - cfg["model_name"], - cfg.get("custom_provider"), + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) + model_string = f"{provider_prefix}/{cfg['model_name']}" gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base if cfg.get("api_version"): gen_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -175,14 +184,18 @@ def create_generate_image_tool( "error": f"Image generation config {config_id} not found" } - model_string = _build_model_string( - db_cfg.provider.value, - db_cfg.model_name, - db_cfg.custom_provider, + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) + model_string = f"{provider_prefix}/{db_cfg.model_name}" gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base if db_cfg.api_version: gen_kwargs["api_version"] = db_cfg.api_version if db_cfg.litellm_params: diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py b/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py new file mode 100644 index 000000000..0ca1191a4 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py @@ -0,0 +1,41 @@ +from typing import Any + +from app.db import SearchSourceConnector +from app.services.composio_service import ComposioService + + +def split_recipients(value: str | None) -> list[str]: + if not value: + return [] + return [recipient.strip() for recipient in value.split(",") if recipient.strip()] + + +def unwrap_composio_data(data: Any) -> Any: + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner) + return inner + return data + + +async def execute_composio_gmail_tool( + connector: SearchSourceConnector, + user_id: str, + tool_name: str, + params: dict[str, Any], +) -> tuple[Any, str | None]: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return None, "Composio connected account ID not found for this Gmail connector." + + result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Gmail error") + + return unwrap_composio_data(result.get("data")), None diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py index 0bd044695..c88b48d2d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_create_gmail_draft_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_gmail_draft tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_gmail_draft tool + """ + del db_session # per-call session — see docstring + @tool async def create_gmail_draft( to: str, @@ -57,246 +75,276 @@ def create_create_gmail_draft_tool( """ logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Gmail tool not properly configured. Please contact support.", } try: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Gmail accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - logger.info( - f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'" - ) - result = request_approval( - action_type="gmail_draft_creation", - tool_name="create_gmail_draft", - params={ - "to": to, - "subject": subject, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.", - } - - final_to = result.params.get("to", to) - final_subject = result.params.get("subject", subject) - final_body = result.params.get("body", body) - final_cc = result.params.get("cc", cc) - final_bcc = result.params.get("bcc", bcc) - final_connector_id = result.params.get("connector_id") - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) + async with async_session_maker() as db_session: + metadata_service = GmailToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), + + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" ) - ) - connector = result.scalars().first() - if not connector: + return {"status": "error", "message": context["error"]} + + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + logger.warning("All Gmail accounts have expired authentication") return { - "status": "error", - "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", + "status": "auth_error", + "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "gmail", } - actual_connector_id = connector.id - logger.info( - f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" - ) + logger.info( + f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'" + ) + result = request_approval( + action_type="gmail_draft_creation", + tool_name="create_gmail_draft", + params={ + "to": to, + "subject": subject, + "body": body, + "cc": cc, + "bcc": bcc, + "connector_id": None, + }, + context=context, + ) - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.", + } - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) + final_to = result.params.get("to", to) + final_subject = result.params.get("subject", subject) + final_body = result.params.get("body", body) + final_cc = result.params.get("cc", cc) + final_bcc = result.params.get("bcc", bcc) + final_connector_id = result.params.get("connector_id") + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _gmail_types = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + ] + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Gmail connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id else: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = token_encryption.decrypt_token( - config_data["refresh_token"] - ) - if config_data.get("client_secret"): - config_data["client_secret"] = token_encryption.decrypt_token( - config_data["client_secret"] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - - message = MIMEText(final_body) - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: - created = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .create(userId="me", body={"message": {"raw": raw}}) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" ) - try: - from sqlalchemy.orm.attributes import flag_modified + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", + } + actual_connector_id = connector.id - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id + logger.info( + f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" + ) + + is_composio_gmail = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ) + if is_composio_gmail: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } + else: + from google.oauth2.credentials import Credentials + + from app.config import config + from app.utils.oauth_security import TokenEncryption + + config_data = dict(connector.config) + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = ( + token_encryption.decrypt_token( + config_data["refresh_token"] + ) + ) + if config_data.get("client_secret"): + config_data["client_secret"] = ( + token_encryption.decrypt_token( + config_data["client_secret"] + ) + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + message = MIMEText(final_body) + message["to"] = final_to + message["subject"] = final_subject + if final_cc: + message["cc"] = final_cc + if final_bcc: + message["bcc"] = final_bcc + raw = base64.urlsafe_b64encode(message.as_bytes()).decode() + + try: + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + split_recipients, ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: + + created, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_CREATE_EMAIL_DRAFT", + { + "user_id": "me", + "recipient_email": final_to, + "subject": final_subject, + "body": final_body, + "cc": split_recipients(final_cc), + "bcc": split_recipients(final_bcc), + "is_html": False, + }, + ) + if error: + raise RuntimeError(error) + if not isinstance(created, dict): + created = {} + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + created = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .drafts() + .create(userId="me", body={"message": {"raw": raw}}) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + try: + from sqlalchemy.orm.attributes import flag_modified - logger.info(f"Gmail draft created: id={created.get('id')}") + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + raise - kb_message_suffix = "" - try: - from app.services.gmail import GmailKBSyncService + logger.info(f"Gmail draft created: id={created.get('id')}") - kb_service = GmailKBSyncService(db_session) - draft_message = created.get("message", {}) - kb_result = await kb_service.sync_after_create( - message_id=draft_message.get("id", ""), - thread_id=draft_message.get("threadId", ""), - subject=final_subject, - sender="me", - date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - body_text=final_body, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - draft_id=created.get("id"), - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: + kb_message_suffix = "" + try: + from app.services.gmail import GmailKBSyncService + + kb_service = GmailKBSyncService(db_session) + draft_message = created.get("message", {}) + kb_result = await kb_service.sync_after_create( + message_id=draft_message.get("id", ""), + thread_id=draft_message.get("threadId", ""), + subject=final_subject, + sender="me", + date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + body_text=final_body, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + draft_id=created.get("id"), + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." - return { - "status": "success", - "draft_id": created.get("id"), - "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}", - } + return { + "status": "success", + "draft_id": created.get("id"), + "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py index deec1627c..464713591 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py @@ -5,7 +5,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -20,6 +20,23 @@ def create_read_gmail_email_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_gmail_email tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_gmail_email tool + """ + del db_session # per-call session — see docstring + @tool async def read_gmail_email(message_id: str) -> dict[str, Any]: """Read the full content of a specific Gmail email by its message ID. @@ -32,60 +49,115 @@ def create_read_gmail_email_tool( Returns: Dictionary with status and the full email content formatted as markdown. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Gmail tool not properly configured."} try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", - } - - from app.agents.new_chat.tools.gmail.search_emails import _build_credentials - - creds = _build_credentials(connector) - - from app.connectors.google_gmail_connector import GoogleGmailConnector - - gmail = GoogleGmailConnector( - credentials=creds, - session=db_session, - user_id=user_id, - connector_id=connector.id, - ) - - detail, error = await gmail.get_message_details(message_id) - if error: - if ( - "re-authenticate" in error.lower() - or "authentication failed" in error.lower() - ): + connector = result.scalars().first() + if not connector: return { - "status": "auth_error", - "message": error, - "connector_type": "gmail", + "status": "error", + "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", } - return {"status": "error", "message": error} - if not detail: + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ): + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found.", + } + + from app.agents.new_chat.tools.gmail.search_emails import ( + _format_gmail_summary, + ) + from app.services.composio_service import ComposioService + + service = ComposioService() + detail, error = await service.get_gmail_message_detail( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + message_id=message_id, + ) + if error: + return {"status": "error", "message": error} + if not detail: + return { + "status": "not_found", + "message": f"Email with ID '{message_id}' not found.", + } + + summary = _format_gmail_summary(detail) + content = ( + f"# {summary['subject']}\n\n" + f"**From:** {summary['from']}\n" + f"**To:** {summary['to']}\n" + f"**Date:** {summary['date']}\n\n" + f"## Message Content\n\n" + f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n" + f"## Message Details\n\n" + f"- **Message ID:** {summary['message_id']}\n" + f"- **Thread ID:** {summary['thread_id']}\n" + ) + return { + "status": "success", + "message_id": summary["message_id"] or message_id, + "content": content, + } + + from app.agents.new_chat.tools.gmail.search_emails import ( + _build_credentials, + ) + + creds = _build_credentials(connector) + + from app.connectors.google_gmail_connector import GoogleGmailConnector + + gmail = GoogleGmailConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, + ) + + detail, error = await gmail.get_message_details(message_id) + if error: + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "gmail", + } + return {"status": "error", "message": error} + + if not detail: + return { + "status": "not_found", + "message": f"Email with ID '{message_id}' not found.", + } + + content = gmail.format_message_to_markdown(detail) + return { - "status": "not_found", - "message": f"Email with ID '{message_id}' not found.", + "status": "success", + "message_id": message_id, + "content": content, } - content = gmail.format_message_to_markdown(detail) - - return {"status": "success", "message_id": message_id, "content": content} - except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py index 2e363609e..3ce154c53 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py @@ -6,7 +6,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector): from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - raise ValueError("Composio connected account ID not found.") - return build_composio_credentials(cca_id) + raise ValueError("Composio connectors must use Composio tool execution.") from google.oauth2.credentials import Credentials @@ -67,11 +62,85 @@ def _build_credentials(connector: SearchSourceConnector): ) +def _gmail_headers(message: dict[str, Any]) -> dict[str, str]: + headers = message.get("payload", {}).get("headers", []) + return { + header.get("name", "").lower(): header.get("value", "") + for header in headers + if isinstance(header, dict) + } + + +def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]: + headers = _gmail_headers(message) + return { + "message_id": message.get("id") or message.get("messageId"), + "thread_id": message.get("threadId"), + "subject": message.get("subject") or headers.get("subject", "No Subject"), + "from": message.get("sender") or headers.get("from", "Unknown"), + "to": message.get("to") or headers.get("to", ""), + "date": message.get("messageTimestamp") or headers.get("date", ""), + "snippet": message.get("snippet") or message.get("messageText", "")[:300], + "labels": message.get("labelIds", []), + } + + +async def _search_composio_gmail( + connector: SearchSourceConnector, + user_id: str, + query: str, + max_results: int, +) -> dict[str, Any]: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found.", + } + + from app.services.composio_service import ComposioService + + service = ComposioService() + messages, _next_token, _estimate, error = await service.get_gmail_messages( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + query=query, + max_results=max_results, + ) + if error: + return {"status": "error", "message": error} + + emails = [_format_gmail_summary(message) for message in messages] + return { + "status": "success", + "emails": emails, + "total": len(emails), + "message": "No emails found." if not emails else None, + } + + def create_search_gmail_tool( db_session: AsyncSession | None = None, search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the search_gmail tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured search_gmail tool + """ + del db_session # per-call session — see docstring + @tool async def search_gmail( query: str, @@ -90,83 +159,92 @@ def create_search_gmail_tool( Dictionary with status and a list of email summaries including message_id, subject, from, date, snippet. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Gmail tool not properly configured."} max_results = min(max_results, 20) try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", - } - - creds = _build_credentials(connector) - - from app.connectors.google_gmail_connector import GoogleGmailConnector - - gmail = GoogleGmailConnector( - credentials=creds, - session=db_session, - user_id=user_id, - connector_id=connector.id, - ) - - messages_list, error = await gmail.get_messages_list( - max_results=max_results, query=query - ) - if error: - if ( - "re-authenticate" in error.lower() - or "authentication failed" in error.lower() - ): + connector = result.scalars().first() + if not connector: return { - "status": "auth_error", - "message": error, - "connector_type": "gmail", + "status": "error", + "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", } - return {"status": "error", "message": error} - if not messages_list: - return { - "status": "success", - "emails": [], - "total": 0, - "message": "No emails found.", - } + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ): + return await _search_composio_gmail( + connector, str(user_id), query, max_results + ) - emails = [] - for msg in messages_list: - detail, err = await gmail.get_message_details(msg["id"]) - if err: - continue - headers = { - h["name"].lower(): h["value"] - for h in detail.get("payload", {}).get("headers", []) - } - emails.append( - { - "message_id": detail.get("id"), - "thread_id": detail.get("threadId"), - "subject": headers.get("subject", "No Subject"), - "from": headers.get("from", "Unknown"), - "to": headers.get("to", ""), - "date": headers.get("date", ""), - "snippet": detail.get("snippet", ""), - "labels": detail.get("labelIds", []), - } + creds = _build_credentials(connector) + + from app.connectors.google_gmail_connector import GoogleGmailConnector + + gmail = GoogleGmailConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, ) - return {"status": "success", "emails": emails, "total": len(emails)} + messages_list, error = await gmail.get_messages_list( + max_results=max_results, query=query + ) + if error: + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "gmail", + } + return {"status": "error", "message": error} + + if not messages_list: + return { + "status": "success", + "emails": [], + "total": 0, + "message": "No emails found.", + } + + emails = [] + for msg in messages_list: + detail, err = await gmail.get_message_details(msg["id"]) + if err: + continue + headers = { + h["name"].lower(): h["value"] + for h in detail.get("payload", {}).get("headers", []) + } + emails.append( + { + "message_id": detail.get("id"), + "thread_id": detail.get("threadId"), + "subject": headers.get("subject", "No Subject"), + "from": headers.get("from", "Unknown"), + "to": headers.get("to", ""), + "date": headers.get("date", ""), + "snippet": detail.get("snippet", ""), + "labels": detail.get("labelIds", []), + } + ) + + return {"status": "success", "emails": emails, "total": len(emails)} except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py index c3f0999f4..4d5aa3bcc 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_send_gmail_email_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the send_gmail_email tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured send_gmail_email tool + """ + del db_session # per-call session — see docstring + @tool async def send_gmail_email( to: str, @@ -58,247 +76,277 @@ def create_send_gmail_email_tool( """ logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Gmail tool not properly configured. Please contact support.", } try: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Gmail accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - logger.info( - f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'" - ) - result = request_approval( - action_type="gmail_email_send", - tool_name="send_gmail_email", - params={ - "to": to, - "subject": subject, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.", - } - - final_to = result.params.get("to", to) - final_subject = result.params.get("subject", subject) - final_body = result.params.get("body", body) - final_cc = result.params.get("cc", cc) - final_bcc = result.params.get("bcc", bcc) - final_connector_id = result.params.get("connector_id") - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) + async with async_session_maker() as db_session: + metadata_service = GmailToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), + + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" ) - ) - connector = result.scalars().first() - if not connector: + return {"status": "error", "message": context["error"]} + + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + logger.warning("All Gmail accounts have expired authentication") return { - "status": "error", - "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", + "status": "auth_error", + "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "gmail", } - actual_connector_id = connector.id - logger.info( - f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" - ) + logger.info( + f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'" + ) + result = request_approval( + action_type="gmail_email_send", + tool_name="send_gmail_email", + params={ + "to": to, + "subject": subject, + "body": body, + "cc": cc, + "bcc": bcc, + "connector_id": None, + }, + context=context, + ) - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.", + } - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) + final_to = result.params.get("to", to) + final_subject = result.params.get("subject", subject) + final_body = result.params.get("body", body) + final_cc = result.params.get("cc", cc) + final_bcc = result.params.get("bcc", bcc) + final_connector_id = result.params.get("connector_id") + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _gmail_types = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + ] + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Gmail connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id else: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = token_encryption.decrypt_token( - config_data["refresh_token"] - ) - if config_data.get("client_secret"): - config_data["client_secret"] = token_encryption.decrypt_token( - config_data["client_secret"] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - - message = MIMEText(final_body) - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: - sent = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .messages() - .send(userId="me", body={"raw": raw}) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" ) - try: - from sqlalchemy.orm.attributes import flag_modified + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", + } + actual_connector_id = connector.id - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id + logger.info( + f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" + ) + + is_composio_gmail = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ) + if is_composio_gmail: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } + else: + from google.oauth2.credentials import Credentials + + from app.config import config + from app.utils.oauth_security import TokenEncryption + + config_data = dict(connector.config) + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = ( + token_encryption.decrypt_token( + config_data["refresh_token"] + ) + ) + if config_data.get("client_secret"): + config_data["client_secret"] = ( + token_encryption.decrypt_token( + config_data["client_secret"] + ) + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + message = MIMEText(final_body) + message["to"] = final_to + message["subject"] = final_subject + if final_cc: + message["cc"] = final_cc + if final_bcc: + message["bcc"] = final_bcc + raw = base64.urlsafe_b64encode(message.as_bytes()).decode() + + try: + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + split_recipients, ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: + + sent, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_SEND_EMAIL", + { + "user_id": "me", + "recipient_email": final_to, + "subject": final_subject, + "body": final_body, + "cc": split_recipients(final_cc), + "bcc": split_recipients(final_bcc), + "is_html": False, + }, + ) + if error: + raise RuntimeError(error) + if not isinstance(sent, dict): + sent = {} + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + sent = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .messages() + .send(userId="me", body={"raw": raw}) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + try: + from sqlalchemy.orm.attributes import flag_modified - logger.info( - f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}" - ) + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + raise - kb_message_suffix = "" - try: - from app.services.gmail import GmailKBSyncService - - kb_service = GmailKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - message_id=sent.get("id", ""), - thread_id=sent.get("threadId", ""), - subject=final_subject, - sender="me", - date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - body_text=final_body, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, + logger.info( + f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}" ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: - kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after send failed: {kb_err}") - kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." - return { - "status": "success", - "message_id": sent.get("id"), - "thread_id": sent.get("threadId"), - "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}", - } + kb_message_suffix = "" + try: + from app.services.gmail import GmailKBSyncService + + kb_service = GmailKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + message_id=sent.get("id", ""), + thread_id=sent.get("threadId", ""), + subject=final_subject, + sender="me", + date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + body_text=final_body, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after send failed: {kb_err}") + kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "message_id": sent.get("id"), + "thread_id": sent.get("threadId"), + "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py index 1f1f6227a..95f5b4e6c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py @@ -7,6 +7,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) @@ -17,6 +18,23 @@ def create_trash_gmail_email_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the trash_gmail_email tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured trash_gmail_email tool + """ + del db_session # per-call session — see docstring + @tool async def trash_gmail_email( email_subject_or_id: str, @@ -55,244 +73,261 @@ def create_trash_gmail_email_tool( f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Gmail tool not properly configured. Please contact support.", } try: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_trash_context( - search_space_id, user_id, email_subject_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Email not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch trash context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Gmail account %s has expired authentication", - account.get("id"), + async with async_session_maker() as db_session: + metadata_service = GmailToolMetadataService(db_session) + context = await metadata_service.get_trash_context( + search_space_id, user_id, email_subject_or_id ) - return { - "status": "auth_error", - "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - email = context["email"] - message_id = email["message_id"] - document_id = email.get("document_id") - connector_id_from_context = context["account"]["id"] + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"Email not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch trash context: {error_msg}") + return {"status": "error", "message": error_msg} - if not message_id: - return { - "status": "error", - "message": "Message ID is missing from the indexed document. Please re-index the email and try again.", - } + account = context.get("account", {}) + if account.get("auth_expired"): + logger.warning( + "Gmail account %s has expired authentication", + account.get("id"), + ) + return { + "status": "auth_error", + "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "gmail", + } - logger.info( - f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="gmail_email_trash", - tool_name="trash_gmail_email", - params={ - "message_id": message_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + email = context["email"] + message_id = email["message_id"] + document_id = email.get("document_id") + connector_id_from_context = context["account"]["id"] - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.", - } - - final_message_id = result.params.get("message_id", message_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this email.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - - logger.info( - f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}" - ) - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not message_id: return { "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", + "message": "Message ID is missing from the indexed document. Please re-index the email and try again.", } - else: - from google.oauth2.credentials import Credentials - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = token_encryption.decrypt_token( - config_data["refresh_token"] - ) - if config_data.get("client_secret"): - config_data["client_secret"] = token_encryption.decrypt_token( - config_data["client_secret"] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + logger.info( + f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})" + ) + result = request_approval( + action_type="gmail_email_trash", + tool_name="trash_gmail_email", + params={ + "message_id": message_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - - try: - await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .messages() - .trash(userId="me", id=final_message_id) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {connector.id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.", } - raise - logger.info(f"Gmail email trashed: message_id={final_message_id}") - - trash_result: dict[str, Any] = { - "status": "success", - "message_id": final_message_id, - "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"Email trashed, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + final_message_id = result.params.get("message_id", message_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb ) - return trash_result + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this email.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _gmail_types = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + ] + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Gmail connector is invalid or has been disconnected.", + } + + logger.info( + f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}" + ) + + is_composio_gmail = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ) + if is_composio_gmail: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } + else: + from google.oauth2.credentials import Credentials + + from app.config import config + from app.utils.oauth_security import TokenEncryption + + config_data = dict(connector.config) + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = ( + token_encryption.decrypt_token( + config_data["refresh_token"] + ) + ) + if config_data.get("client_secret"): + config_data["client_secret"] = ( + token_encryption.decrypt_token( + config_data["client_secret"] + ) + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + try: + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + ) + + _trashed, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_MOVE_TO_TRASH", + {"user_id": "me", "message_id": final_message_id}, + ) + if error: + raise RuntimeError(error) + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .messages() + .trash(userId="me", id=final_message_id) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {connector.id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + if not connector.config.get("auth_expired"): + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + connector.id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info(f"Gmail email trashed: message_id={final_message_id}") + + trash_result: dict[str, Any] = { + "status": "success", + "message_id": final_message_id, + "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"Email trashed, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py index 91178cd21..129b7defb 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_update_gmail_draft_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the update_gmail_draft tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured update_gmail_draft tool + """ + del db_session # per-call session — see docstring + @tool async def update_gmail_draft( draft_subject_or_id: str, @@ -76,294 +94,329 @@ def create_update_gmail_draft_tool( f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Gmail tool not properly configured. Please contact support.", } try: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, draft_subject_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Draft not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch update context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Gmail account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - email = context["email"] - message_id = email["message_id"] - document_id = email.get("document_id") - connector_id_from_context = account["id"] - draft_id_from_context = context.get("draft_id") - - original_subject = email.get("subject", draft_subject_or_id) - final_subject_default = subject if subject else original_subject - final_to_default = to if to else "" - - logger.info( - f"Requesting approval for updating Gmail draft: '{original_subject}' " - f"(message_id={message_id}, draft_id={draft_id_from_context})" - ) - result = request_approval( - action_type="gmail_draft_update", - tool_name="update_gmail_draft", - params={ - "message_id": message_id, - "draft_id": draft_id_from_context, - "to": final_to_default, - "subject": final_subject_default, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.", - } - - final_to = result.params.get("to", final_to_default) - final_subject = result.params.get("subject", final_subject_default) - final_body = result.params.get("body", body) - final_cc = result.params.get("cc", cc) - final_bcc = result.params.get("bcc", bcc) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_draft_id = result.params.get("draft_id", draft_id_from_context) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this draft.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - - logger.info( - f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}" - ) - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = token_encryption.decrypt_token( - config_data["refresh_token"] - ) - if config_data.get("client_secret"): - config_data["client_secret"] = token_encryption.decrypt_token( - config_data["client_secret"] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + async with async_session_maker() as db_session: + metadata_service = GmailToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, draft_subject_or_id ) - from googleapiclient.discovery import build + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"Draft not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch update context: {error_msg}") + return {"status": "error", "message": error_msg} - gmail_service = build("gmail", "v1", credentials=creds) - - # Resolve draft_id if not already available - if not final_draft_id: - logger.info( - f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}" - ) - final_draft_id = await _find_draft_id_by_message( - gmail_service, message_id - ) - - if not final_draft_id: - return { - "status": "error", - "message": ( - "Could not find this draft in Gmail. " - "It may have already been sent or deleted." - ), - } - - message = MIMEText(final_body) - if final_to: - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .update( - userId="me", - id=final_draft_id, - body={"message": {"raw": raw}}, - ) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: + account = context.get("account", {}) + if account.get("auth_expired"): logger.warning( - f"Insufficient permissions for connector {connector.id}: {api_err}" + "Gmail account %s has expired authentication", + account.get("id"), ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + "status": "auth_error", + "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "gmail", } - if isinstance(api_err, HttpError) and api_err.resp.status == 404: + + email = context["email"] + message_id = email["message_id"] + document_id = email.get("document_id") + connector_id_from_context = account["id"] + draft_id_from_context = context.get("draft_id") + + original_subject = email.get("subject", draft_subject_or_id) + final_subject_default = subject if subject else original_subject + final_to_default = to if to else "" + + logger.info( + f"Requesting approval for updating Gmail draft: '{original_subject}' " + f"(message_id={message_id}, draft_id={draft_id_from_context})" + ) + result = request_approval( + action_type="gmail_draft_update", + tool_name="update_gmail_draft", + params={ + "message_id": message_id, + "draft_id": draft_id_from_context, + "to": final_to_default, + "subject": final_subject_default, + "body": body, + "cc": cc, + "bcc": bcc, + "connector_id": connector_id_from_context, + }, + context=context, + ) + + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.", + } + + final_to = result.params.get("to", final_to_default) + final_subject = result.params.get("subject", final_subject_default) + final_body = result.params.get("body", body) + final_cc = result.params.get("cc", cc) + final_bcc = result.params.get("bcc", bcc) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_draft_id = result.params.get("draft_id", draft_id_from_context) + + if not final_connector_id: return { "status": "error", - "message": "Draft no longer exists in Gmail. It may have been sent or deleted.", + "message": "No connector found for this draft.", } - raise - logger.info(f"Gmail draft updated: id={updated.get('id')}") + from sqlalchemy.future import select - kb_message_suffix = "" - if document_id: - try: - from sqlalchemy.future import select as sa_select - from sqlalchemy.orm.attributes import flag_modified + from app.db import SearchSourceConnector, SearchSourceConnectorType - from app.db import Document + _gmail_types = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + ] - doc_result = await db_session.execute( - sa_select(Document).filter(Document.id == document_id) + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), ) - document = doc_result.scalars().first() - if document: - document.source_markdown = final_body - document.title = final_subject - meta = dict(document.document_metadata or {}) - meta["subject"] = final_subject - meta["draft_id"] = updated.get("id", final_draft_id) - updated_msg = updated.get("message", {}) - if updated_msg.get("id"): - meta["message_id"] = updated_msg["id"] - document.document_metadata = meta - flag_modified(document, "document_metadata") - await db_session.commit() - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - logger.info( - f"KB document {document_id} updated for draft {final_draft_id}" + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Gmail connector is invalid or has been disconnected.", + } + + logger.info( + f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}" + ) + + is_composio_gmail = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ) + if is_composio_gmail: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } + else: + from google.oauth2.credentials import Credentials + + from app.config import config + from app.utils.oauth_security import TokenEncryption + + config_data = dict(connector.config) + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = ( + token_encryption.decrypt_token( + config_data["refresh_token"] + ) + ) + if config_data.get("client_secret"): + config_data["client_secret"] = ( + token_encryption.decrypt_token( + config_data["client_secret"] + ) + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + # Resolve draft_id if not already available + if not final_draft_id: + logger.info( + f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}" + ) + if is_composio_gmail: + final_draft_id = await _find_composio_draft_id_by_message( + connector, user_id, message_id ) else: - kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB update after draft edit failed: {kb_err}") - await db_session.rollback() - kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync." + from googleapiclient.discovery import build - return { - "status": "success", - "draft_id": updated.get("id"), - "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}", - } + gmail_service = build("gmail", "v1", credentials=creds) + final_draft_id = await _find_draft_id_by_message( + gmail_service, message_id + ) + + if not final_draft_id: + return { + "status": "error", + "message": ( + "Could not find this draft in Gmail. " + "It may have already been sent or deleted." + ), + } + + message = MIMEText(final_body) + if final_to: + message["to"] = final_to + message["subject"] = final_subject + if final_cc: + message["cc"] = final_cc + if final_bcc: + message["bcc"] = final_bcc + raw = base64.urlsafe_b64encode(message.as_bytes()).decode() + + try: + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + split_recipients, + ) + + updated, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_UPDATE_DRAFT", + { + "user_id": "me", + "draft_id": final_draft_id, + "recipient_email": final_to, + "subject": final_subject, + "body": final_body, + "cc": split_recipients(final_cc), + "bcc": split_recipients(final_bcc), + "is_html": False, + }, + ) + if error: + raise RuntimeError(error) + if not isinstance(updated, dict): + updated = {} + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + updated = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .drafts() + .update( + userId="me", + id=final_draft_id, + body={"message": {"raw": raw}}, + ) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {connector.id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + if not connector.config.get("auth_expired"): + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + connector.id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + if isinstance(api_err, HttpError) and api_err.resp.status == 404: + return { + "status": "error", + "message": "Draft no longer exists in Gmail. It may have been sent or deleted.", + } + raise + + logger.info(f"Gmail draft updated: id={updated.get('id')}") + + kb_message_suffix = "" + if document_id: + try: + from sqlalchemy.future import select as sa_select + from sqlalchemy.orm.attributes import flag_modified + + from app.db import Document + + doc_result = await db_session.execute( + sa_select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + document.source_markdown = final_body + document.title = final_subject + meta = dict(document.document_metadata or {}) + meta["subject"] = final_subject + meta["draft_id"] = updated.get("id", final_draft_id) + updated_msg = updated.get("message", {}) + if updated_msg.get("id"): + meta["message_id"] = updated_msg["id"] + document.document_metadata = meta + flag_modified(document, "document_metadata") + await db_session.commit() + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + logger.info( + f"KB document {document_id} updated for draft {final_draft_id}" + ) + else: + kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB update after draft edit failed: {kb_err}") + await db_session.rollback() + kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync." + + return { + "status": "success", + "draft_id": updated.get("id"), + "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt @@ -408,3 +461,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str except Exception as e: logger.warning(f"Failed to look up draft by message_id: {e}") return None + + +async def _find_composio_draft_id_by_message( + connector: Any, user_id: str, message_id: str +) -> str | None: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + ) + + page_token = "" + while True: + params: dict[str, Any] = { + "user_id": "me", + "max_results": 100, + "verbose": False, + } + if page_token: + params["page_token"] = page_token + + data, error = await execute_composio_gmail_tool( + connector, user_id, "GMAIL_LIST_DRAFTS", params + ) + if error or not isinstance(data, dict): + return None + + for draft in data.get("drafts", []): + if draft.get("message", {}).get("id") == message_id: + return draft.get("id") + + page_token = data.get("nextPageToken") or data.get("next_page_token") or "" + if not page_token: + return None diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py index 37bcf083e..dec92cc8b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_create_calendar_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_calendar_event tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_calendar_event tool + """ + del db_session # per-call session — see docstring + @tool async def create_calendar_event( summary: str, @@ -60,254 +78,294 @@ def create_create_calendar_event_tool( f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Calendar tool not properly configured. Please contact support.", } try: - metadata_service = GoogleCalendarToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning( - "All Google Calendar accounts have expired authentication" + async with async_session_maker() as db_session: + metadata_service = GoogleCalendarToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - return { - "status": "auth_error", - "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "google_calendar", - } - logger.info( - f"Requesting approval for creating calendar event: summary='{summary}'" - ) - result = request_approval( - action_type="google_calendar_event_creation", - tool_name="create_calendar_event", - params={ - "summary": summary, - "start_datetime": start_datetime, - "end_datetime": end_datetime, - "description": description, - "location": location, - "attendees": attendees, - "timezone": context.get("timezone"), - "connector_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The event was not created. Do not ask again or suggest alternatives.", - } - - final_summary = result.params.get("summary", summary) - final_start_datetime = result.params.get("start_datetime", start_datetime) - final_end_datetime = result.params.get("end_datetime", end_datetime) - final_description = result.params.get("description", description) - final_location = result.params.get("location", location) - final_attendees = result.params.get("attendees", attendees) - final_connector_id = result.params.get("connector_id") - - if not final_summary or not final_summary.strip(): - return {"status": "error", "message": "Event summary cannot be empty."} - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _calendar_types = [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Calendar connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), + return {"status": "error", "message": context["error"]} + + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + logger.warning( + "All Google Calendar accounts have expired authentication" ) + return { + "status": "auth_error", + "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "google_calendar", + } + + logger.info( + f"Requesting approval for creating calendar event: summary='{summary}'" ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", - } - actual_connector_id = connector.id - - logger.info( - f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}" - ) - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: - return { - "status": "error", - "message": "Composio connected account ID not found for this connector.", - } - else: - config_data = dict(connector.config) - - from app.config import config as app_config - from app.utils.oauth_security import TokenEncryption - - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and app_config.SECRET_KEY: - token_encryption = TokenEncryption(app_config.SECRET_KEY) - for key in ("token", "refresh_token", "client_secret"): - if config_data.get(key): - config_data[key] = token_encryption.decrypt_token( - config_data[key] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + result = request_approval( + action_type="google_calendar_event_creation", + tool_name="create_calendar_event", + params={ + "summary": summary, + "start_datetime": start_datetime, + "end_datetime": end_datetime, + "description": description, + "location": location, + "attendees": attendees, + "timezone": context.get("timezone"), + "connector_id": None, + }, + context=context, ) - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The event was not created. Do not ask again or suggest alternatives.", + } - tz = context.get("timezone", "UTC") - event_body: dict[str, Any] = { - "summary": final_summary, - "start": {"dateTime": final_start_datetime, "timeZone": tz}, - "end": {"dateTime": final_end_datetime, "timeZone": tz}, - } - if final_description: - event_body["description"] = final_description - if final_location: - event_body["location"] = final_location - if final_attendees: - event_body["attendees"] = [ - {"email": e.strip()} for e in final_attendees if e.strip() + final_summary = result.params.get("summary", summary) + final_start_datetime = result.params.get( + "start_datetime", start_datetime + ) + final_end_datetime = result.params.get("end_datetime", end_datetime) + final_description = result.params.get("description", description) + final_location = result.params.get("location", location) + final_attendees = result.params.get("attendees", attendees) + final_connector_id = result.params.get("connector_id") + + if not final_summary or not final_summary.strip(): + return { + "status": "error", + "message": "Event summary cannot be empty.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _calendar_types = [ + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, ] - try: - created = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .insert(calendarId="primary", body=event_body) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_calendar_types), + ) ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", - } - raise - - logger.info( - f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}" - ) - - kb_message_suffix = "" - try: - from app.services.google_calendar import GoogleCalendarKBSyncService - - kb_service = GoogleCalendarKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - event_id=created.get("id"), - event_summary=final_summary, - calendar_id="primary", - start_time=final_start_datetime, - end_time=final_end_datetime, - location=final_location, - html_link=created.get("htmlLink"), - description=final_description, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Google Calendar connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id else: - kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_calendar_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", + } + actual_connector_id = connector.id - return { - "status": "success", - "event_id": created.get("id"), - "html_link": created.get("htmlLink"), - "message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}", - } + logger.info( + f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}" + ) + + is_composio_calendar = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ) + if is_composio_calendar: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this connector.", + } + else: + config_data = dict(connector.config) + + from app.config import config as app_config + from app.utils.oauth_security import TokenEncryption + + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and app_config.SECRET_KEY: + token_encryption = TokenEncryption(app_config.SECRET_KEY) + for key in ("token", "refresh_token", "client_secret"): + if config_data.get(key): + config_data[key] = token_encryption.decrypt_token( + config_data[key] + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + tz = context.get("timezone", "UTC") + event_body: dict[str, Any] = { + "summary": final_summary, + "start": {"dateTime": final_start_datetime, "timeZone": tz}, + "end": {"dateTime": final_end_datetime, "timeZone": tz}, + } + if final_description: + event_body["description"] = final_description + if final_location: + event_body["location"] = final_location + if final_attendees: + event_body["attendees"] = [ + {"email": e.strip()} for e in final_attendees if e.strip() + ] + + try: + if is_composio_calendar: + from app.services.composio_service import ComposioService + + composio_params = { + "calendar_id": "primary", + "summary": final_summary, + "start_datetime": final_start_datetime, + "end_datetime": final_end_datetime, + "timezone": tz, + "attendees": final_attendees or [], + } + if final_description: + composio_params["description"] = final_description + if final_location: + composio_params["location"] = final_location + + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_CREATE_EVENT", + params=composio_params, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get( + "error", "Unknown Composio Calendar error" + ) + ) + created = composio_result.get("data", {}) + if isinstance(created, dict): + created = created.get("data", created) + if isinstance(created, dict): + created = created.get("response_data", created) + else: + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + created = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .insert(calendarId="primary", body=event_body) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}" + ) + + kb_message_suffix = "" + try: + from app.services.google_calendar import GoogleCalendarKBSyncService + + kb_service = GoogleCalendarKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + event_id=created.get("id"), + event_summary=final_summary, + calendar_id="primary", + start_time=final_start_datetime, + end_time=final_end_datetime, + location=final_location, + html_link=created.get("htmlLink"), + description=final_description, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "event_id": created.get("id"), + "html_link": created.get("htmlLink"), + "message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py index 4d9d69b4b..e7e891b08 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_delete_calendar_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the delete_calendar_event tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured delete_calendar_event tool + """ + del db_session # per-call session — see docstring + @tool async def delete_calendar_event( event_title_or_id: str, @@ -54,240 +72,258 @@ def create_delete_calendar_event_tool( f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Calendar tool not properly configured. Please contact support.", } try: - metadata_service = GoogleCalendarToolMetadataService(db_session) - context = await metadata_service.get_deletion_context( - search_space_id, user_id, event_title_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Event not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch deletion context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Google Calendar account %s has expired authentication", - account.get("id"), + async with async_session_maker() as db_session: + metadata_service = GoogleCalendarToolMetadataService(db_session) + context = await metadata_service.get_deletion_context( + search_space_id, user_id, event_title_or_id ) - return { - "status": "auth_error", - "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "google_calendar", - } - event = context["event"] - event_id = event["event_id"] - document_id = event.get("document_id") - connector_id_from_context = context["account"]["id"] + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"Event not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch deletion context: {error_msg}") + return {"status": "error", "message": error_msg} - if not event_id: - return { - "status": "error", - "message": "Event ID is missing from the indexed document. Please re-index the event and try again.", - } + account = context.get("account", {}) + if account.get("auth_expired"): + logger.warning( + "Google Calendar account %s has expired authentication", + account.get("id"), + ) + return { + "status": "auth_error", + "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "google_calendar", + } - logger.info( - f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="google_calendar_event_deletion", - tool_name="delete_calendar_event", - params={ - "event_id": event_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + event = context["event"] + event_id = event["event_id"] + document_id = event.get("document_id") + connector_id_from_context = context["account"]["id"] - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.", - } - - final_event_id = result.params.get("event_id", event_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this event.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _calendar_types = [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Calendar connector is invalid or has been disconnected.", - } - - actual_connector_id = connector.id - - logger.info( - f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}" - ) - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not event_id: return { "status": "error", - "message": "Composio connected account ID not found for this connector.", + "message": "Event ID is missing from the indexed document. Please re-index the event and try again.", } - else: - config_data = dict(connector.config) - from app.config import config as app_config - from app.utils.oauth_security import TokenEncryption - - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and app_config.SECRET_KEY: - token_encryption = TokenEncryption(app_config.SECRET_KEY) - for key in ("token", "refresh_token", "client_secret"): - if config_data.get(key): - config_data[key] = token_encryption.decrypt_token( - config_data[key] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + logger.info( + f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})" + ) + result = request_approval( + action_type="google_calendar_event_deletion", + tool_name="delete_calendar_event", + params={ + "event_id": event_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - - try: - await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .delete(calendarId="primary", eventId=final_event_id) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.", } - raise - logger.info(f"Calendar event deleted: event_id={final_event_id}") - - delete_result: dict[str, Any] = { - "status": "success", - "event_id": final_event_id, - "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - delete_result["warning"] = ( - f"Event deleted, but failed to remove from knowledge base: {e!s}" - ) - - delete_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - delete_result["message"] = ( - f"{delete_result.get('message', '')} (also removed from knowledge base)" + final_event_id = result.params.get("event_id", event_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb ) - return delete_result + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this event.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _calendar_types = [ + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, + ] + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_calendar_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Google Calendar connector is invalid or has been disconnected.", + } + + actual_connector_id = connector.id + + logger.info( + f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}" + ) + + is_composio_calendar = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ) + if is_composio_calendar: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this connector.", + } + else: + config_data = dict(connector.config) + + from app.config import config as app_config + from app.utils.oauth_security import TokenEncryption + + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and app_config.SECRET_KEY: + token_encryption = TokenEncryption(app_config.SECRET_KEY) + for key in ("token", "refresh_token", "client_secret"): + if config_data.get(key): + config_data[key] = token_encryption.decrypt_token( + config_data[key] + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + try: + if is_composio_calendar: + from app.services.composio_service import ComposioService + + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_DELETE_EVENT", + params={ + "calendar_id": "primary", + "event_id": final_event_id, + }, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get( + "error", "Unknown Composio Calendar error" + ) + ) + else: + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .delete(calendarId="primary", eventId=final_event_id) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info(f"Calendar event deleted: event_id={final_event_id}") + + delete_result: dict[str, Any] = { + "status": "success", + "event_id": final_event_id, + "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + delete_result["warning"] = ( + f"Event deleted, but failed to remove from knowledge base: {e!s}" + ) + + delete_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + delete_result["message"] = ( + f"{delete_result.get('message', '')} (also removed from knowledge base)" + ) + + return delete_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py index dc6adb822..e5f18f675 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from app.agents.new_chat.tools.gmail.search_emails import _build_credentials -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -16,11 +16,57 @@ _CALENDAR_TYPES = [ ] +def _to_calendar_boundary(value: str, *, is_end: bool) -> str: + if "T" in value: + return value + time = "23:59:59" if is_end else "00:00:00" + return f"{value}T{time}Z" + + +def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]: + events = [] + for ev in events_raw: + start = ev.get("start", {}) + end = ev.get("end", {}) + attendees_raw = ev.get("attendees", []) + events.append( + { + "event_id": ev.get("id"), + "summary": ev.get("summary", "No Title"), + "start": start.get("dateTime") or start.get("date", ""), + "end": end.get("dateTime") or end.get("date", ""), + "location": ev.get("location", ""), + "description": ev.get("description", ""), + "html_link": ev.get("htmlLink", ""), + "attendees": [a.get("email", "") for a in attendees_raw[:10]], + "status": ev.get("status", ""), + } + ) + return events + + def create_search_calendar_events_tool( db_session: AsyncSession | None = None, search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the search_calendar_events tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured search_calendar_events tool + """ + del db_session # per-call session — see docstring + @tool async def search_calendar_events( start_date: str, @@ -38,7 +84,7 @@ def create_search_calendar_events_tool( Dictionary with status and a list of events including event_id, summary, start, end, location, attendees. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Calendar tool not properly configured.", @@ -47,76 +93,85 @@ def create_search_calendar_events_tool( max_results = min(max_results, 50) try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES), + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES), + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", - } + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", + } - creds = _build_credentials(connector) - - from app.connectors.google_calendar_connector import GoogleCalendarConnector - - cal = GoogleCalendarConnector( - credentials=creds, - session=db_session, - user_id=user_id, - connector_id=connector.id, - ) - - events_raw, error = await cal.get_all_primary_calendar_events( - start_date=start_date, - end_date=end_date, - max_results=max_results, - ) - - if error: if ( - "re-authenticate" in error.lower() - or "authentication failed" in error.lower() + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR ): - return { - "status": "auth_error", - "message": error, - "connector_type": "google_calendar", - } - if "no events found" in error.lower(): - return { - "status": "success", - "events": [], - "total": 0, - "message": error, - } - return {"status": "error", "message": error} + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this connector.", + } - events = [] - for ev in events_raw: - start = ev.get("start", {}) - end = ev.get("end", {}) - attendees_raw = ev.get("attendees", []) - events.append( - { - "event_id": ev.get("id"), - "summary": ev.get("summary", "No Title"), - "start": start.get("dateTime") or start.get("date", ""), - "end": end.get("dateTime") or end.get("date", ""), - "location": ev.get("location", ""), - "description": ev.get("description", ""), - "html_link": ev.get("htmlLink", ""), - "attendees": [a.get("email", "") for a in attendees_raw[:10]], - "status": ev.get("status", ""), - } - ) + from app.services.composio_service import ComposioService - return {"status": "success", "events": events, "total": len(events)} + events_raw, error = await ComposioService().get_calendar_events( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + time_min=_to_calendar_boundary(start_date, is_end=False), + time_max=_to_calendar_boundary(end_date, is_end=True), + max_results=max_results, + ) + if not events_raw and not error: + error = "No events found in the specified date range." + else: + creds = _build_credentials(connector) + + from app.connectors.google_calendar_connector import ( + GoogleCalendarConnector, + ) + + cal = GoogleCalendarConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, + ) + + events_raw, error = await cal.get_all_primary_calendar_events( + start_date=start_date, + end_date=end_date, + max_results=max_results, + ) + + if error: + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "google_calendar", + } + if "no events found" in error.lower(): + return { + "status": "success", + "events": [], + "total": 0, + "message": error, + } + return {"status": "error", "message": error} + + events = _format_calendar_events(events_raw) + + return {"status": "success", "events": events, "total": len(events)} except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py index 259f52bba..b8561fee6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) @@ -33,6 +34,23 @@ def create_update_calendar_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the update_calendar_event tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured update_calendar_event tool + """ + del db_session # per-call session — see docstring + @tool async def update_calendar_event( event_title_or_id: str, @@ -74,272 +92,317 @@ def create_update_calendar_event_tool( """ logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Calendar tool not properly configured. Please contact support.", } try: - metadata_service = GoogleCalendarToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, event_title_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Event not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch update context: {error_msg}") - return {"status": "error", "message": error_msg} - - if context.get("auth_expired"): - logger.warning("Google Calendar account has expired authentication") - return { - "status": "auth_error", - "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "google_calendar", - } - - event = context["event"] - event_id = event["event_id"] - document_id = event.get("document_id") - connector_id_from_context = context["account"]["id"] - - if not event_id: - return { - "status": "error", - "message": "Event ID is missing from the indexed document. Please re-index the event and try again.", - } - - logger.info( - f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})" - ) - result = request_approval( - action_type="google_calendar_event_update", - tool_name="update_calendar_event", - params={ - "event_id": event_id, - "document_id": document_id, - "connector_id": connector_id_from_context, - "new_summary": new_summary, - "new_start_datetime": new_start_datetime, - "new_end_datetime": new_end_datetime, - "new_description": new_description, - "new_location": new_location, - "new_attendees": new_attendees, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.", - } - - final_event_id = result.params.get("event_id", event_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_new_summary = result.params.get("new_summary", new_summary) - final_new_start_datetime = result.params.get( - "new_start_datetime", new_start_datetime - ) - final_new_end_datetime = result.params.get( - "new_end_datetime", new_end_datetime - ) - final_new_description = result.params.get( - "new_description", new_description - ) - final_new_location = result.params.get("new_location", new_location) - final_new_attendees = result.params.get("new_attendees", new_attendees) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this event.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _calendar_types = [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), + async with async_session_maker() as db_session: + metadata_service = GoogleCalendarToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, event_title_or_id ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Calendar connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"Event not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch update context: {error_msg}") + return {"status": "error", "message": error_msg} - logger.info( - f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}" - ) + if context.get("auth_expired"): + logger.warning("Google Calendar account has expired authentication") + return { + "status": "auth_error", + "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "google_calendar", + } - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials + event = context["event"] + event_id = event["event_id"] + document_id = event.get("document_id") + connector_id_from_context = context["account"]["id"] - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not event_id: return { "status": "error", - "message": "Composio connected account ID not found for this connector.", + "message": "Event ID is missing from the indexed document. Please re-index the event and try again.", } - else: - config_data = dict(connector.config) - from app.config import config as app_config - from app.utils.oauth_security import TokenEncryption - - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and app_config.SECRET_KEY: - token_encryption = TokenEncryption(app_config.SECRET_KEY) - for key in ("token", "refresh_token", "client_secret"): - if config_data.get(key): - config_data[key] = token_encryption.decrypt_token( - config_data[key] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + logger.info( + f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})" + ) + result = request_approval( + action_type="google_calendar_event_update", + tool_name="update_calendar_event", + params={ + "event_id": event_id, + "document_id": document_id, + "connector_id": connector_id_from_context, + "new_summary": new_summary, + "new_start_datetime": new_start_datetime, + "new_end_datetime": new_end_datetime, + "new_description": new_description, + "new_location": new_location, + "new_attendees": new_attendees, + }, + context=context, ) - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.", + } - update_body: dict[str, Any] = {} - if final_new_summary is not None: - update_body["summary"] = final_new_summary - if final_new_start_datetime is not None: - update_body["start"] = _build_time_body( - final_new_start_datetime, context + final_event_id = result.params.get("event_id", event_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context ) - if final_new_end_datetime is not None: - update_body["end"] = _build_time_body(final_new_end_datetime, context) - if final_new_description is not None: - update_body["description"] = final_new_description - if final_new_location is not None: - update_body["location"] = final_new_location - if final_new_attendees is not None: - update_body["attendees"] = [ - {"email": e.strip()} for e in final_new_attendees if e.strip() + final_new_summary = result.params.get("new_summary", new_summary) + final_new_start_datetime = result.params.get( + "new_start_datetime", new_start_datetime + ) + final_new_end_datetime = result.params.get( + "new_end_datetime", new_end_datetime + ) + final_new_description = result.params.get( + "new_description", new_description + ) + final_new_location = result.params.get("new_location", new_location) + final_new_attendees = result.params.get("new_attendees", new_attendees) + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this event.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _calendar_types = [ + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, ] - if not update_body: - return { - "status": "error", - "message": "No changes specified. Please provide at least one field to update.", - } - - try: - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .patch( - calendarId="primary", - eventId=final_event_id, - body=update_body, - ) - .execute() - ), + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_calendar_types), + ) ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) + connector = result.scalars().first() + if not connector: return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + "status": "error", + "message": "Selected Google Calendar connector is invalid or has been disconnected.", } - raise - logger.info(f"Calendar event updated: event_id={final_event_id}") + actual_connector_id = connector.id - kb_message_suffix = "" - if document_id is not None: - try: - from app.services.google_calendar import GoogleCalendarKBSyncService + logger.info( + f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}" + ) - kb_service = GoogleCalendarKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=document_id, - event_id=final_event_id, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, + is_composio_calendar = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ) + if is_composio_calendar: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this connector.", + } + else: + config_data = dict(connector.config) + + from app.config import config as app_config + from app.utils.oauth_security import TokenEncryption + + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and app_config.SECRET_KEY: + token_encryption = TokenEncryption(app_config.SECRET_KEY) + for key in ("token", "refresh_token", "client_secret"): + if config_data.get(key): + config_data[key] = token_encryption.decrypt_token( + config_data[key] + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after update failed: {kb_err}") - kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." - return { - "status": "success", - "event_id": final_event_id, - "html_link": updated.get("htmlLink"), - "message": f"Successfully updated the calendar event.{kb_message_suffix}", - } + update_body: dict[str, Any] = {} + if final_new_summary is not None: + update_body["summary"] = final_new_summary + if final_new_start_datetime is not None: + update_body["start"] = _build_time_body( + final_new_start_datetime, context + ) + if final_new_end_datetime is not None: + update_body["end"] = _build_time_body( + final_new_end_datetime, context + ) + if final_new_description is not None: + update_body["description"] = final_new_description + if final_new_location is not None: + update_body["location"] = final_new_location + if final_new_attendees is not None: + update_body["attendees"] = [ + {"email": e.strip()} for e in final_new_attendees if e.strip() + ] + + if not update_body: + return { + "status": "error", + "message": "No changes specified. Please provide at least one field to update.", + } + + try: + if is_composio_calendar: + from app.services.composio_service import ComposioService + + composio_params: dict[str, Any] = { + "calendar_id": "primary", + "event_id": final_event_id, + } + if final_new_summary is not None: + composio_params["summary"] = final_new_summary + if final_new_start_datetime is not None: + composio_params["start_time"] = final_new_start_datetime + if final_new_end_datetime is not None: + composio_params["end_time"] = final_new_end_datetime + if final_new_description is not None: + composio_params["description"] = final_new_description + if final_new_location is not None: + composio_params["location"] = final_new_location + if final_new_attendees is not None: + composio_params["attendees"] = [ + e.strip() for e in final_new_attendees if e.strip() + ] + if not _is_date_only( + final_new_start_datetime or final_new_end_datetime or "" + ): + composio_params["timezone"] = context.get("timezone", "UTC") + + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_PATCH_EVENT", + params=composio_params, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get( + "error", "Unknown Composio Calendar error" + ) + ) + updated = composio_result.get("data", {}) + if isinstance(updated, dict): + updated = updated.get("data", updated) + if isinstance(updated, dict): + updated = updated.get("response_data", updated) + else: + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + updated = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .patch( + calendarId="primary", + eventId=final_event_id, + body=update_body, + ) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info(f"Calendar event updated: event_id={final_event_id}") + + kb_message_suffix = "" + if document_id is not None: + try: + from app.services.google_calendar import ( + GoogleCalendarKBSyncService, + ) + + kb_service = GoogleCalendarKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=document_id, + event_id=final_event_id, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after update failed: {kb_err}") + kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." + + return { + "status": "success", + "event_id": final_event_id, + "html_link": updated.get("htmlLink"), + "message": f"Successfully updated the calendar event.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py index f36db8f3f..66199ca67 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py @@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.google_drive.client import GoogleDriveClient from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET +from app.db import async_session_maker from app.services.google_drive import GoogleDriveToolMetadataService logger = logging.getLogger(__name__) @@ -23,6 +24,25 @@ def create_create_google_drive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_google_drive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Google Drive connector + user_id: User ID for fetching user-specific context + + Returns: + Configured create_google_drive_file tool + """ + del db_session # per-call session — see docstring + @tool async def create_google_drive_file( name: str, @@ -65,7 +85,7 @@ def create_create_google_drive_file_tool( f"create_google_drive_file called: name='{name}', type='{file_type}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Drive tool not properly configured. Please contact support.", @@ -78,195 +98,232 @@ def create_create_google_drive_file_tool( } try: - metadata_service = GoogleDriveToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) + async with async_session_maker() as db_session: + metadata_service = GoogleDriveToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id + ) - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Google Drive accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "google_drive", - } - - logger.info( - f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'" - ) - result = request_approval( - action_type="google_drive_file_creation", - tool_name="create_google_drive_file", - params={ - "name": name, - "file_type": file_type, - "content": content, - "connector_id": None, - "parent_folder_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The file was not created. Do not ask again or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_file_type = result.params.get("file_type", file_type) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_id = result.params.get("parent_folder_id") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - mime_type = _MIME_MAP.get(final_file_type) - if not mime_type: - return { - "status": "error", - "message": f"Unsupported file type '{final_file_type}'.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _drive_types = [ - SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Drive connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.", - } - actual_connector_id = connector.id + return {"status": "error", "message": context["error"]} - logger.info( - f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" - ) - - pre_built_creds = None - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) - - client = GoogleDriveClient( - session=db_session, - connector_id=actual_connector_id, - credentials=pre_built_creds, - ) - try: - created = await client.create_file( - name=final_name, - mime_type=mime_type, - parent_folder_id=final_parent_folder_id, - content=final_content, - ) - except HttpError as http_err: - if http_err.resp.status == 403: + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {http_err}" + "All Google Drive accounts have expired authentication" ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + "status": "auth_error", + "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "google_drive", } - raise - logger.info( - f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" - ) - - kb_message_suffix = "" - try: - from app.services.google_drive import GoogleDriveKBSyncService - - kb_service = GoogleDriveKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=created.get("id"), - file_name=created.get("name", final_name), - mime_type=mime_type, - web_view_link=created.get("webViewLink"), - content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, + logger.info( + f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'" + ) + result = request_approval( + action_type="google_drive_file_creation", + tool_name="create_google_drive_file", + params={ + "name": name, + "file_type": file_type, + "content": content, + "connector_id": None, + "parent_folder_id": None, + }, + context=context, ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - return { - "status": "success", - "file_id": created.get("id"), - "name": created.get("name"), - "web_view_link": created.get("webViewLink"), - "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The file was not created. Do not ask again or suggest alternatives.", + } + + final_name = result.params.get("name", name) + final_file_type = result.params.get("file_type", file_type) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_id = result.params.get("parent_folder_id") + + if not final_name or not final_name.strip(): + return {"status": "error", "message": "File name cannot be empty."} + + mime_type = _MIME_MAP.get(final_file_type) + if not mime_type: + return { + "status": "error", + "message": f"Unsupported file type '{final_file_type}'.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _drive_types = [ + SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, + ] + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_drive_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Google Drive connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id + else: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_drive_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.", + } + actual_connector_id = connector.id + + logger.info( + f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" + ) + + is_composio_drive = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR + ) + if is_composio_drive: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Drive connector.", + } + client = GoogleDriveClient( + session=db_session, + connector_id=actual_connector_id, + ) + try: + if is_composio_drive: + from app.services.composio_service import ComposioService + + params: dict[str, Any] = { + "name": final_name, + "mimeType": mime_type, + "fields": "id,name,webViewLink,mimeType", + } + if final_parent_folder_id: + params["parents"] = [final_parent_folder_id] + if final_content: + params["description"] = final_content[:4096] + + result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLEDRIVE_CREATE_FILE", + params=params, + entity_id=f"surfsense_{user_id}", + ) + if not result.get("success"): + raise RuntimeError( + result.get("error", "Unknown Composio Drive error") + ) + created = result.get("data", {}) + if isinstance(created, dict): + created = created.get("data", created) + if isinstance(created, dict): + created = created.get("response_data", created) + if not isinstance(created, dict): + created = {} + else: + created = await client.create_file( + name=final_name, + mime_type=mime_type, + parent_folder_id=final_parent_folder_id, + content=final_content, + ) + except HttpError as http_err: + if http_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {http_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" + ) + + kb_message_suffix = "" + try: + from app.services.google_drive import GoogleDriveKBSyncService + + kb_service = GoogleDriveKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + file_id=created.get("id"), + file_name=created.get("name", final_name), + mime_type=mime_type, + web_view_link=created.get("webViewLink"), + content=final_content, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "file_id": created.get("id"), + "name": created.get("name"), + "web_view_link": created.get("webViewLink"), + "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py index 832afff0d..b3c9240d8 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py @@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.google_drive.client import GoogleDriveClient +from app.db import async_session_maker from app.services.google_drive import GoogleDriveToolMetadataService logger = logging.getLogger(__name__) @@ -17,6 +18,25 @@ def create_delete_google_drive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the delete_google_drive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Google Drive connector + user_id: User ID for fetching user-specific context + + Returns: + Configured delete_google_drive_file tool + """ + del db_session # per-call session — see docstring + @tool async def delete_google_drive_file( file_name: str, @@ -55,197 +75,214 @@ def create_delete_google_drive_file_tool( f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Drive tool not properly configured. Please contact support.", } try: - metadata_service = GoogleDriveToolMetadataService(db_session) - context = await metadata_service.get_trash_context( - search_space_id, user_id, file_name - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"File not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch trash context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Google Drive account %s has expired authentication", - account.get("id"), + async with async_session_maker() as db_session: + metadata_service = GoogleDriveToolMetadataService(db_session) + context = await metadata_service.get_trash_context( + search_space_id, user_id, file_name ) - return { - "status": "auth_error", - "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "google_drive", - } - file = context["file"] - file_id = file["file_id"] - document_id = file.get("document_id") - connector_id_from_context = context["account"]["id"] + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"File not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch trash context: {error_msg}") + return {"status": "error", "message": error_msg} - if not file_id: - return { - "status": "error", - "message": "File ID is missing from the indexed document. Please re-index the file and try again.", - } - - logger.info( - f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="google_drive_file_trash", - tool_name="delete_google_drive_file", - params={ - "file_id": file_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.", - } - - final_file_id = result.params.get("file_id", file_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this file.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _drive_types = [ - SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Drive connector is invalid or has been disconnected.", - } - - logger.info( - f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" - ) - - pre_built_creds = None - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) - - client = GoogleDriveClient( - session=db_session, - connector_id=connector.id, - credentials=pre_built_creds, - ) - try: - await client.trash_file(file_id=final_file_id) - except HttpError as http_err: - if http_err.resp.status == 403: + account = context.get("account", {}) + if account.get("auth_expired"): logger.warning( - f"Insufficient permissions for connector {connector.id}: {http_err}" + "Google Drive account %s has expired authentication", + account.get("id"), ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + "status": "auth_error", + "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "google_drive", } - raise - logger.info( - f"Google Drive file deleted (moved to trash): file_id={final_file_id}" - ) + file = context["file"] + file_id = file["file_id"] + document_id = file.get("document_id") + connector_id_from_context = context["account"]["id"] - trash_result: dict[str, Any] = { - "status": "success", - "file_id": final_file_id, - "message": f"Successfully moved '{file['name']}' to trash.", - } + if not file_id: + return { + "status": "error", + "message": "File ID is missing from the indexed document. Please re-index the file and try again.", + } - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"File moved to trash, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + logger.info( + f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})" + ) + result = request_approval( + action_type="google_drive_file_trash", + tool_name="delete_google_drive_file", + params={ + "file_id": file_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - return trash_result + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.", + } + + final_file_id = result.params.get("file_id", file_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this file.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _drive_types = [ + SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, + ] + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_drive_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Google Drive connector is invalid or has been disconnected.", + } + + logger.info( + f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" + ) + + is_composio_drive = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR + ) + if is_composio_drive: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Drive connector.", + } + + client = GoogleDriveClient( + session=db_session, + connector_id=connector.id, + ) + try: + if is_composio_drive: + from app.services.composio_service import ComposioService + + result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLEDRIVE_TRASH_FILE", + params={"file_id": final_file_id}, + entity_id=f"surfsense_{user_id}", + ) + if not result.get("success"): + raise RuntimeError( + result.get("error", "Unknown Composio Drive error") + ) + else: + await client.trash_file(file_id=final_file_id) + except HttpError as http_err: + if http_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {connector.id}: {http_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + if not connector.config.get("auth_expired"): + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + connector.id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Google Drive file deleted (moved to trash): file_id={final_file_id}" + ) + + trash_result: dict[str, Any] = { + "status": "success", + "file_id": final_file_id, + "message": f"Successfully moved '{file['name']}' to trash.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"File moved to trash, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py index 92248c2c9..5b64929de 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset( { "create_gmail_draft", "update_gmail_draft", + "create_calendar_event", "create_notion_page", "create_confluence_page", "create_google_drive_file", diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py index 8b40dde65..0b04f1642 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py @@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector +from app.db import async_session_maker from app.services.jira import JiraToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,28 @@ def create_create_jira_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): + """Factory function to create the create_jira_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Per-call sessions also + keep the request's outer transaction free of long-running Jira API + blocking. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Jira connector + user_id: User ID for fetching user-specific context + connector_id: Optional specific connector ID (if known) + + Returns: + Configured create_jira_issue tool + """ + del db_session # per-call session — see docstring + @tool async def create_jira_issue( project_key: str, @@ -49,158 +72,167 @@ def create_create_jira_issue_tool( f"create_jira_issue called: project_key='{project_key}', summary='{summary}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Jira tool not properly configured."} try: - metadata_service = JiraToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected Jira accounts need re-authentication.", - "connector_type": "jira", - } - - result = request_approval( - action_type="jira_issue_creation", - tool_name="create_jira_issue", - params={ - "project_key": project_key, - "summary": summary, - "issue_type": issue_type, - "description": description, - "priority": priority, - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_project_key = result.params.get("project_key", project_key) - final_summary = result.params.get("summary", summary) - final_issue_type = result.params.get("issue_type", issue_type) - final_description = result.params.get("description", description) - final_priority = result.params.get("priority", priority) - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_summary or not final_summary.strip(): - return {"status": "error", "message": "Issue summary cannot be empty."} - if not final_project_key: - return {"status": "error", "message": "A project must be selected."} - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = JiraToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: - return {"status": "error", "message": "No Jira connector found."} - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, - ) + + if "error" in context: + return {"status": "error", "message": context["error"]} + + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + return { + "status": "auth_error", + "message": "All connected Jira accounts need re-authentication.", + "connector_type": "jira", + } + + result = request_approval( + action_type="jira_issue_creation", + tool_name="create_jira_issue", + params={ + "project_key": project_key, + "summary": summary, + "issue_type": issue_type, + "description": description, + "priority": priority, + "connector_id": connector_id, + }, + context=context, ) - connector = result.scalars().first() - if not connector: + + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_project_key = result.params.get("project_key", project_key) + final_summary = result.params.get("summary", summary) + final_issue_type = result.params.get("issue_type", issue_type) + final_description = result.params.get("description", description) + final_priority = result.params.get("priority", priority) + final_connector_id = result.params.get("connector_id", connector_id) + + if not final_summary or not final_summary.strip(): return { "status": "error", - "message": "Selected Jira connector is invalid.", + "message": "Issue summary cannot be empty.", } + if not final_project_key: + return {"status": "error", "message": "A project must be selected."} - try: - jira_history = JiraHistoryConnector( - session=db_session, connector_id=actual_connector_id - ) - jira_client = await jira_history._get_jira_client() - api_result = await asyncio.to_thread( - jira_client.create_issue, - project_key=final_project_key, - summary=final_summary, - issue_type=final_issue_type, - description=final_description, - priority=final_priority, - ) - except Exception as api_err: - if "status code 403" in str(api_err).lower(): - try: - _conn = connector - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - pass - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + from sqlalchemy.future import select - issue_key = api_result.get("key", "") - issue_url = ( - f"{jira_history._base_url}/browse/{issue_key}" - if jira_history._base_url and issue_key - else "" - ) + from app.db import SearchSourceConnector, SearchSourceConnectorType - kb_message_suffix = "" - try: - from app.services.jira import JiraKBSyncService - - kb_service = JiraKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - issue_id=issue_key, - issue_identifier=issue_key, - issue_title=final_summary, - description=final_description, - state="To Do", - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Jira connector found.", + } + actual_connector_id = connector.id else: - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Jira connector is invalid.", + } - return { - "status": "success", - "issue_key": issue_key, - "issue_url": issue_url, - "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}", - } + try: + jira_history = JiraHistoryConnector( + session=db_session, connector_id=actual_connector_id + ) + jira_client = await jira_history._get_jira_client() + api_result = await asyncio.to_thread( + jira_client.create_issue, + project_key=final_project_key, + summary=final_summary, + issue_type=final_issue_type, + description=final_description, + priority=final_priority, + ) + except Exception as api_err: + if "status code 403" in str(api_err).lower(): + try: + _conn = connector + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + issue_key = api_result.get("key", "") + issue_url = ( + f"{jira_history._base_url}/browse/{issue_key}" + if jira_history._base_url and issue_key + else "" + ) + + kb_message_suffix = "" + try: + from app.services.jira import JiraKBSyncService + + kb_service = JiraKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + issue_id=issue_key, + issue_identifier=issue_key, + issue_title=final_summary, + description=final_description, + state="To Do", + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "issue_key": issue_key, + "issue_url": issue_url, + "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py index 6466c80ea..c41aedad9 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py @@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector +from app.db import async_session_maker from app.services.jira import JiraToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,26 @@ def create_delete_jira_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): + """Factory function to create the delete_jira_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Jira connector + user_id: User ID for fetching user-specific context + connector_id: Optional specific connector ID (if known) + + Returns: + Configured delete_jira_issue tool + """ + del db_session # per-call session — see docstring + @tool async def delete_jira_issue( issue_title_or_key: str, @@ -44,130 +65,136 @@ def create_delete_jira_issue_tool( f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Jira tool not properly configured."} try: - metadata_service = JiraToolMetadataService(db_session) - context = await metadata_service.get_deletion_context( - search_space_id, user_id, issue_title_or_key - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "jira", - } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - - issue_data = context["issue"] - issue_key = issue_data["issue_id"] - document_id = issue_data["document_id"] - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="jira_issue_deletion", - tool_name="delete_jira_issue", - params={ - "issue_key": issue_key, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_issue_key = result.params.get("issue_key", issue_key) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this issue.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = JiraToolMetadataService(db_session) + context = await metadata_service.get_deletion_context( + search_space_id, user_id, issue_title_or_key ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Jira connector is invalid.", - } - try: - jira_history = JiraHistoryConnector( - session=db_session, connector_id=final_connector_id + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "jira", + } + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} + + issue_data = context["issue"] + issue_key = issue_data["issue_id"] + document_id = issue_data["document_id"] + connector_id_from_context = context.get("account", {}).get("id") + + result = request_approval( + action_type="jira_issue_deletion", + tool_name="delete_jira_issue", + params={ + "issue_key": issue_key, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - jira_client = await jira_history._get_jira_client() - await asyncio.to_thread(jira_client.delete_issue, final_issue_key) - except Exception as api_err: - if "status code 403" in str(api_err).lower(): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", } - raise - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document + final_issue_key = result.params.get("issue_key", issue_key) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this issue.", + } + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Jira connector is invalid.", + } - message = f"Jira issue {final_issue_key} deleted successfully." - if deleted_from_kb: - message += " Also removed from the knowledge base." + try: + jira_history = JiraHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + jira_client = await jira_history._get_jira_client() + await asyncio.to_thread(jira_client.delete_issue, final_issue_key) + except Exception as api_err: + if "status code 403" in str(api_err).lower(): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + } + raise - return { - "status": "success", - "issue_key": final_issue_key, - "deleted_from_kb": deleted_from_kb, - "message": message, - } + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + + message = f"Jira issue {final_issue_key} deleted successfully." + if deleted_from_kb: + message += " Also removed from the knowledge base." + + return { + "status": "success", + "issue_key": final_issue_key, + "deleted_from_kb": deleted_from_kb, + "message": message, + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py index f6e586a2e..0fd7b28b3 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py @@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector +from app.db import async_session_maker from app.services.jira import JiraToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,26 @@ def create_update_jira_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): + """Factory function to create the update_jira_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Jira connector + user_id: User ID for fetching user-specific context + connector_id: Optional specific connector ID (if known) + + Returns: + Configured update_jira_issue tool + """ + del db_session # per-call session — see docstring + @tool async def update_jira_issue( issue_title_or_key: str, @@ -48,169 +69,177 @@ def create_update_jira_issue_tool( f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Jira tool not properly configured."} try: - metadata_service = JiraToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, issue_title_or_key - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "jira", - } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - - issue_data = context["issue"] - issue_key = issue_data["issue_id"] - document_id = issue_data.get("document_id") - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="jira_issue_update", - tool_name="update_jira_issue", - params={ - "issue_key": issue_key, - "document_id": document_id, - "new_summary": new_summary, - "new_description": new_description, - "new_priority": new_priority, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_issue_key = result.params.get("issue_key", issue_key) - final_summary = result.params.get("new_summary", new_summary) - final_description = result.params.get("new_description", new_description) - final_priority = result.params.get("new_priority", new_priority) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_document_id = result.params.get("document_id", document_id) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this issue.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = JiraToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, issue_title_or_key ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Jira connector is invalid.", - } - fields: dict[str, Any] = {} - if final_summary: - fields["summary"] = final_summary - if final_description is not None: - fields["description"] = { - "type": "doc", - "version": 1, - "content": [ - { - "type": "paragraph", - "content": [{"type": "text", "text": final_description}], + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "jira", } - ], - } - if final_priority: - fields["priority"] = {"name": final_priority} + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} - if not fields: - return {"status": "error", "message": "No changes specified."} + issue_data = context["issue"] + issue_key = issue_data["issue_id"] + document_id = issue_data.get("document_id") + connector_id_from_context = context.get("account", {}).get("id") - try: - jira_history = JiraHistoryConnector( - session=db_session, connector_id=final_connector_id + result = request_approval( + action_type="jira_issue_update", + tool_name="update_jira_issue", + params={ + "issue_key": issue_key, + "document_id": document_id, + "new_summary": new_summary, + "new_description": new_description, + "new_priority": new_priority, + "connector_id": connector_id_from_context, + }, + context=context, ) - jira_client = await jira_history._get_jira_client() - await asyncio.to_thread( - jira_client.update_issue, final_issue_key, fields - ) - except Exception as api_err: - if "status code 403" in str(api_err).lower(): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", } - raise - issue_url = ( - f"{jira_history._base_url}/browse/{final_issue_key}" - if jira_history._base_url and final_issue_key - else "" - ) + final_issue_key = result.params.get("issue_key", issue_key) + final_summary = result.params.get("new_summary", new_summary) + final_description = result.params.get( + "new_description", new_description + ) + final_priority = result.params.get("new_priority", new_priority) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_document_id = result.params.get("document_id", document_id) - kb_message_suffix = "" - if final_document_id: - try: - from app.services.jira import JiraKBSyncService + from sqlalchemy.future import select - kb_service = JiraKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=final_document_id, - issue_id=final_issue_key, - user_id=user_id, - search_space_id=search_space_id, + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this issue.", + } + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Jira connector is invalid.", + } + + fields: dict[str, Any] = {} + if final_summary: + fields["summary"] = final_summary + if final_description is not None: + fields["description"] = { + "type": "doc", + "version": 1, + "content": [ + { + "type": "paragraph", + "content": [ + {"type": "text", "text": final_description} + ], + } + ], + } + if final_priority: + fields["priority"] = {"name": final_priority} + + if not fields: + return {"status": "error", "message": "No changes specified."} + + try: + jira_history = JiraHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + jira_client = await jira_history._get_jira_client() + await asyncio.to_thread( + jira_client.update_issue, final_issue_key, fields + ) + except Exception as api_err: + if "status code 403" in str(api_err).lower(): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + issue_url = ( + f"{jira_history._base_url}/browse/{final_issue_key}" + if jira_history._base_url and final_issue_key + else "" + ) + + kb_message_suffix = "" + if final_document_id: + try: + from app.services.jira import JiraKBSyncService + + kb_service = JiraKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=final_document_id, + issue_id=final_issue_key, + user_id=user_id, + search_space_id=search_space_id, ) - else: + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = ( + " The knowledge base will be updated in the next sync." + ) + except Exception as kb_err: + logger.warning(f"KB sync after update failed: {kb_err}") kb_message_suffix = ( " The knowledge base will be updated in the next sync." ) - except Exception as kb_err: - logger.warning(f"KB sync after update failed: {kb_err}") - kb_message_suffix = ( - " The knowledge base will be updated in the next sync." - ) - return { - "status": "success", - "issue_key": final_issue_key, - "issue_url": issue_url, - "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}", - } + return { + "status": "success", + "issue_key": final_issue_key, + "issue_url": issue_url, + "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py index ff254e133..f897bee7a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector +from app.db import async_session_maker from app.services.linear import LinearToolMetadataService logger = logging.getLogger(__name__) @@ -17,11 +18,17 @@ def create_create_linear_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): - """ - Factory function to create the create_linear_issue tool. + """Factory function to create the create_linear_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Args: - db_session: Database session for accessing the Linear connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. search_space_id: Search space ID to find the Linear connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_create_linear_issue_tool( Returns: Configured create_linear_issue tool """ + del db_session # per-call session — see docstring @tool async def create_linear_issue( @@ -65,7 +73,7 @@ def create_create_linear_issue_tool( """ logger.info(f"create_linear_issue called: title='{title}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Linear tool not properly configured - missing required parameters" ) @@ -75,160 +83,170 @@ def create_create_linear_issue_tool( } try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - workspaces = context.get("workspaces", []) - if workspaces and all(w.get("auth_expired") for w in workspaces): - logger.warning("All Linear accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "linear", - } - - logger.info(f"Requesting approval for creating Linear issue: '{title}'") - result = request_approval( - action_type="linear_issue_creation", - tool_name="create_linear_issue", - params={ - "title": title, - "description": description, - "team_id": None, - "state_id": None, - "assignee_id": None, - "priority": None, - "label_ids": [], - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue creation rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_title = result.params.get("title", title) - final_description = result.params.get("description", description) - final_team_id = result.params.get("team_id") - final_state_id = result.params.get("state_id") - final_assignee_id = result.params.get("assignee_id") - final_priority = result.params.get("priority") - final_label_ids = result.params.get("label_ids") or [] - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_title or not final_title.strip(): - logger.error("Title is empty or contains only whitespace") - return {"status": "error", "message": "Issue title cannot be empty."} - if not final_team_id: - return { - "status": "error", - "message": "A team must be selected to create an issue.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = LinearToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: + + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" + ) + return {"status": "error", "message": context["error"]} + + workspaces = context.get("workspaces", []) + if workspaces and all(w.get("auth_expired") for w in workspaces): + logger.warning("All Linear accounts have expired authentication") + return { + "status": "auth_error", + "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "linear", + } + + logger.info(f"Requesting approval for creating Linear issue: '{title}'") + result = request_approval( + action_type="linear_issue_creation", + tool_name="create_linear_issue", + params={ + "title": title, + "description": description, + "team_id": None, + "state_id": None, + "assignee_id": None, + "priority": None, + "label_ids": [], + "connector_id": connector_id, + }, + context=context, + ) + + if result.rejected: + logger.info("Linear issue creation rejected by user") + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_title = result.params.get("title", title) + final_description = result.params.get("description", description) + final_team_id = result.params.get("team_id") + final_state_id = result.params.get("state_id") + final_assignee_id = result.params.get("assignee_id") + final_priority = result.params.get("priority") + final_label_ids = result.params.get("label_ids") or [] + final_connector_id = result.params.get("connector_id", connector_id) + + if not final_title or not final_title.strip(): + logger.error("Title is empty or contains only whitespace") return { "status": "error", - "message": "No Linear connector found. Please connect Linear in your workspace settings.", + "message": "Issue title cannot be empty.", } - actual_connector_id = connector.id - logger.info(f"Found Linear connector: id={actual_connector_id}") - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: + if not final_team_id: return { "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", + "message": "A team must be selected to create an issue.", } - logger.info(f"Validated Linear connector: id={actual_connector_id}") - logger.info( - f"Creating Linear issue with final params: title='{final_title}'" - ) - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) - result = await linear_client.create_issue( - team_id=final_team_id, - title=final_title, - description=final_description, - state_id=final_state_id, - assignee_id=final_assignee_id, - priority=final_priority, - label_ids=final_label_ids if final_label_ids else None, - ) + from sqlalchemy.future import select - if result.get("status") == "error": - logger.error(f"Failed to create Linear issue: {result.get('message')}") - return {"status": "error", "message": result.get("message")} + from app.db import SearchSourceConnector, SearchSourceConnectorType - logger.info( - f"Linear issue created: {result.get('identifier')} - {result.get('title')}" - ) - - kb_message_suffix = "" - try: - from app.services.linear import LinearKBSyncService - - kb_service = LinearKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - issue_id=result.get("id"), - issue_identifier=result.get("identifier", ""), - issue_title=result.get("title", final_title), - issue_url=result.get("url"), - description=final_description, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Linear connector found. Please connect Linear in your workspace settings.", + } + actual_connector_id = connector.id + logger.info(f"Found Linear connector: id={actual_connector_id}") else: - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Linear connector is invalid or has been disconnected.", + } + logger.info(f"Validated Linear connector: id={actual_connector_id}") - return { - "status": "success", - "issue_id": result.get("id"), - "identifier": result.get("identifier"), - "url": result.get("url"), - "message": (result.get("message", "") + kb_message_suffix), - } + logger.info( + f"Creating Linear issue with final params: title='{final_title}'" + ) + linear_client = LinearConnector( + session=db_session, connector_id=actual_connector_id + ) + result = await linear_client.create_issue( + team_id=final_team_id, + title=final_title, + description=final_description, + state_id=final_state_id, + assignee_id=final_assignee_id, + priority=final_priority, + label_ids=final_label_ids if final_label_ids else None, + ) + + if result.get("status") == "error": + logger.error( + f"Failed to create Linear issue: {result.get('message')}" + ) + return {"status": "error", "message": result.get("message")} + + logger.info( + f"Linear issue created: {result.get('identifier')} - {result.get('title')}" + ) + + kb_message_suffix = "" + try: + from app.services.linear import LinearKBSyncService + + kb_service = LinearKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + issue_id=result.get("id"), + issue_identifier=result.get("identifier", ""), + issue_title=result.get("title", final_title), + issue_url=result.get("url"), + description=final_description, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "issue_id": result.get("id"), + "identifier": result.get("identifier"), + "url": result.get("url"), + "message": (result.get("message", "") + kb_message_suffix), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py index 29ef0cdf2..c5039a8eb 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector +from app.db import async_session_maker from app.services.linear import LinearToolMetadataService logger = logging.getLogger(__name__) @@ -17,11 +18,17 @@ def create_delete_linear_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): - """ - Factory function to create the delete_linear_issue tool. + """Factory function to create the delete_linear_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Args: - db_session: Database session for accessing the Linear connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. search_space_id: Search space ID to find the Linear connector user_id: User ID for finding the correct Linear connector connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_delete_linear_issue_tool( Returns: Configured delete_linear_issue tool """ + del db_session # per-call session — see docstring @tool async def delete_linear_issue( @@ -73,7 +81,7 @@ def create_delete_linear_issue_tool( f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Linear tool not properly configured - missing required parameters" ) @@ -83,149 +91,152 @@ def create_delete_linear_issue_tool( } try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_delete_context( - search_space_id, user_id, issue_ref - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - logger.warning(f"Auth expired for delete context: {error_msg}") - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "linear", - } - if "not found" in error_msg.lower(): - logger.warning(f"Issue not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - else: - logger.error(f"Failed to fetch delete context: {error_msg}") - return {"status": "error", "message": error_msg} - - issue_id = context["issue"]["id"] - issue_identifier = context["issue"].get("identifier", "") - document_id = context["issue"]["document_id"] - connector_id_from_context = context.get("workspace", {}).get("id") - - logger.info( - f"Requesting approval for deleting Linear issue: '{issue_ref}' " - f"(id={issue_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="linear_issue_deletion", - tool_name="delete_linear_issue", - params={ - "issue_id": issue_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue deletion rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_issue_id = result.params.get("issue_id", issue_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - logger.info( - f"Deleting Linear issue with final params: issue_id={final_issue_id}, " - f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = LinearToolMetadataService(db_session) + context = await metadata_service.get_delete_context( + search_space_id, user_id, issue_ref ) - connector = result.scalars().first() - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + logger.warning(f"Auth expired for delete context: {error_msg}") + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "linear", + } + if "not found" in error_msg.lower(): + logger.warning(f"Issue not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + else: + logger.error(f"Failed to fetch delete context: {error_msg}") + return {"status": "error", "message": error_msg} + + issue_id = context["issue"]["id"] + issue_identifier = context["issue"].get("identifier", "") + document_id = context["issue"]["document_id"] + connector_id_from_context = context.get("workspace", {}).get("id") + + logger.info( + f"Requesting approval for deleting Linear issue: '{issue_ref}' " + f"(id={issue_id}, delete_from_kb={delete_from_kb})" + ) + result = request_approval( + action_type="linear_issue_deletion", + tool_name="delete_linear_issue", + params={ + "issue_id": issue_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, + ) + + if result.rejected: + logger.info("Linear issue deletion rejected by user") + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_issue_id = result.params.get("issue_id", issue_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + logger.info( + f"Deleting Linear issue with final params: issue_id={final_issue_id}, " + f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if final_connector_id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) ) + connector = result.scalars().first() + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Linear connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id + logger.info(f"Validated Linear connector: id={actual_connector_id}") + else: + logger.error("No connector found for this issue") return { "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", + "message": "No connector found for this issue.", } - actual_connector_id = connector.id - logger.info(f"Validated Linear connector: id={actual_connector_id}") - else: - logger.error("No connector found for this issue") - return { - "status": "error", - "message": "No connector found for this issue.", - } - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) + linear_client = LinearConnector( + session=db_session, connector_id=actual_connector_id + ) - result = await linear_client.archive_issue(issue_id=final_issue_id) + result = await linear_client.archive_issue(issue_id=final_issue_id) - logger.info( - f"archive_issue result: {result.get('status')} - {result.get('message', '')}" - ) + logger.info( + f"archive_issue result: {result.get('status')} - {result.get('message', '')}" + ) - deleted_from_kb = False - if ( - result.get("status") == "success" - and final_delete_from_kb - and document_id - ): - try: - from app.db import Document + deleted_from_kb = False + if ( + result.get("status") == "success" + and final_delete_from_kb + and document_id + ): + try: + from app.db import Document - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + result["warning"] = ( + f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}" ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - result["warning"] = ( - f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}" - ) - if result.get("status") == "success": - result["deleted_from_kb"] = deleted_from_kb - if issue_identifier: - result["message"] = ( - f"Issue {issue_identifier} archived successfully." - ) - if deleted_from_kb: - result["message"] = ( - f"{result.get('message', '')} Also removed from the knowledge base." - ) + if result.get("status") == "success": + result["deleted_from_kb"] = deleted_from_kb + if issue_identifier: + result["message"] = ( + f"Issue {issue_identifier} archived successfully." + ) + if deleted_from_kb: + result["message"] = ( + f"{result.get('message', '')} Also removed from the knowledge base." + ) - return result + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py index f35d0dddd..d610ce2b7 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector +from app.db import async_session_maker from app.services.linear import LinearKBSyncService, LinearToolMetadataService logger = logging.getLogger(__name__) @@ -17,11 +18,17 @@ def create_update_linear_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): - """ - Factory function to create the update_linear_issue tool. + """Factory function to create the update_linear_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Args: - db_session: Database session for accessing the Linear connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. search_space_id: Search space ID to find the Linear connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_update_linear_issue_tool( Returns: Configured update_linear_issue tool """ + del db_session # per-call session — see docstring @tool async def update_linear_issue( @@ -86,7 +94,7 @@ def create_update_linear_issue_tool( """ logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Linear tool not properly configured - missing required parameters" ) @@ -96,176 +104,177 @@ def create_update_linear_issue_tool( } try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, issue_ref - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - logger.warning(f"Auth expired for update context: {error_msg}") - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "linear", - } - if "not found" in error_msg.lower(): - logger.warning(f"Issue not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - else: - logger.error(f"Failed to fetch update context: {error_msg}") - return {"status": "error", "message": error_msg} - - issue_id = context["issue"]["id"] - document_id = context["issue"]["document_id"] - connector_id_from_context = context.get("workspace", {}).get("id") - - team = context.get("team", {}) - new_state_id = _resolve_state(team, new_state_name) - new_assignee_id = _resolve_assignee(team, new_assignee_email) - new_label_ids = _resolve_labels(team, new_label_names) - - logger.info( - f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})" - ) - result = request_approval( - action_type="linear_issue_update", - tool_name="update_linear_issue", - params={ - "issue_id": issue_id, - "document_id": document_id, - "new_title": new_title, - "new_description": new_description, - "new_state_id": new_state_id, - "new_assignee_id": new_assignee_id, - "new_priority": new_priority, - "new_label_ids": new_label_ids, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue update rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_issue_id = result.params.get("issue_id", issue_id) - final_document_id = result.params.get("document_id", document_id) - final_new_title = result.params.get("new_title", new_title) - final_new_description = result.params.get( - "new_description", new_description - ) - final_new_state_id = result.params.get("new_state_id", new_state_id) - final_new_assignee_id = result.params.get( - "new_assignee_id", new_assignee_id - ) - final_new_priority = result.params.get("new_priority", new_priority) - final_new_label_ids: list[str] | None = result.params.get( - "new_label_ids", new_label_ids - ) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - - if not final_connector_id: - logger.error("No connector found for this issue") - return { - "status": "error", - "message": "No connector found for this issue.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = LinearToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, issue_ref ) - ) - connector = result.scalars().first() - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", - } - logger.info(f"Validated Linear connector: id={final_connector_id}") - logger.info( - f"Updating Linear issue with final params: issue_id={final_issue_id}" - ) - linear_client = LinearConnector( - session=db_session, connector_id=final_connector_id - ) - updated_issue = await linear_client.update_issue( - issue_id=final_issue_id, - title=final_new_title, - description=final_new_description, - state_id=final_new_state_id, - assignee_id=final_new_assignee_id, - priority=final_new_priority, - label_ids=final_new_label_ids, - ) + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + logger.warning(f"Auth expired for update context: {error_msg}") + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "linear", + } + if "not found" in error_msg.lower(): + logger.warning(f"Issue not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + else: + logger.error(f"Failed to fetch update context: {error_msg}") + return {"status": "error", "message": error_msg} - if updated_issue.get("status") == "error": - logger.error( - f"Failed to update Linear issue: {updated_issue.get('message')}" - ) - return { - "status": "error", - "message": updated_issue.get("message"), - } + issue_id = context["issue"]["id"] + document_id = context["issue"]["document_id"] + connector_id_from_context = context.get("workspace", {}).get("id") - logger.info( - f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}" - ) + team = context.get("team", {}) + new_state_id = _resolve_state(team, new_state_name) + new_assignee_id = _resolve_assignee(team, new_assignee_email) + new_label_ids = _resolve_labels(team, new_label_names) - if final_document_id is not None: logger.info( - f"Updating knowledge base for document {final_document_id}..." + f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})" ) - kb_service = LinearKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=final_document_id, - issue_id=final_issue_id, - user_id=user_id, - search_space_id=search_space_id, + result = request_approval( + action_type="linear_issue_update", + tool_name="update_linear_issue", + params={ + "issue_id": issue_id, + "document_id": document_id, + "new_title": new_title, + "new_description": new_description, + "new_state_id": new_state_id, + "new_assignee_id": new_assignee_id, + "new_priority": new_priority, + "new_label_ids": new_label_ids, + "connector_id": connector_id_from_context, + }, + context=context, ) - if kb_result["status"] == "success": - logger.info( - f"Knowledge base successfully updated for issue {final_issue_id}" - ) - kb_message = " Your knowledge base has also been updated." - elif kb_result["status"] == "not_indexed": - kb_message = " This issue will be added to your knowledge base in the next scheduled sync." - else: - logger.warning( - f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}" - ) - kb_message = " Your knowledge base will be updated in the next scheduled sync." - else: - kb_message = "" - identifier = updated_issue.get("identifier") - default_msg = f"Issue {identifier} updated successfully." - return { - "status": "success", - "identifier": identifier, - "url": updated_issue.get("url"), - "message": f"{updated_issue.get('message', default_msg)}{kb_message}", - } + if result.rejected: + logger.info("Linear issue update rejected by user") + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_issue_id = result.params.get("issue_id", issue_id) + final_document_id = result.params.get("document_id", document_id) + final_new_title = result.params.get("new_title", new_title) + final_new_description = result.params.get( + "new_description", new_description + ) + final_new_state_id = result.params.get("new_state_id", new_state_id) + final_new_assignee_id = result.params.get( + "new_assignee_id", new_assignee_id + ) + final_new_priority = result.params.get("new_priority", new_priority) + final_new_label_ids: list[str] | None = result.params.get( + "new_label_ids", new_label_ids + ) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + + if not final_connector_id: + logger.error("No connector found for this issue") + return { + "status": "error", + "message": "No connector found for this issue.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Linear connector is invalid or has been disconnected.", + } + logger.info(f"Validated Linear connector: id={final_connector_id}") + + logger.info( + f"Updating Linear issue with final params: issue_id={final_issue_id}" + ) + linear_client = LinearConnector( + session=db_session, connector_id=final_connector_id + ) + updated_issue = await linear_client.update_issue( + issue_id=final_issue_id, + title=final_new_title, + description=final_new_description, + state_id=final_new_state_id, + assignee_id=final_new_assignee_id, + priority=final_new_priority, + label_ids=final_new_label_ids, + ) + + if updated_issue.get("status") == "error": + logger.error( + f"Failed to update Linear issue: {updated_issue.get('message')}" + ) + return { + "status": "error", + "message": updated_issue.get("message"), + } + + logger.info( + f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}" + ) + + if final_document_id is not None: + logger.info( + f"Updating knowledge base for document {final_document_id}..." + ) + kb_service = LinearKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=final_document_id, + issue_id=final_issue_id, + user_id=user_id, + search_space_id=search_space_id, + ) + if kb_result["status"] == "success": + logger.info( + f"Knowledge base successfully updated for issue {final_issue_id}" + ) + kb_message = " Your knowledge base has also been updated." + elif kb_result["status"] == "not_indexed": + kb_message = " This issue will be added to your knowledge base in the next scheduled sync." + else: + logger.warning( + f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}" + ) + kb_message = " Your knowledge base will be updated in the next scheduled sync." + else: + kb_message = "" + + identifier = updated_issue.get("identifier") + default_msg = f"Issue {identifier} updated successfully." + return { + "status": "success", + "identifier": identifier, + "url": updated_issue.get("url"), + "message": f"{updated_issue.get('message', default_msg)}{kb_message}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py index 0a24a988f..65c177d7a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py @@ -6,6 +6,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers @@ -17,6 +18,23 @@ def create_create_luma_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_luma_event tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_luma_event tool + """ + del db_session # per-call session — see docstring + @tool async def create_luma_event( name: str, @@ -40,83 +58,86 @@ def create_create_luma_event_tool( IMPORTANT: - If status is "rejected", the user explicitly declined. Do NOT retry. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Luma tool not properly configured."} try: - connector = await get_luma_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Luma connector found."} + async with async_session_maker() as db_session: + connector = await get_luma_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Luma connector found."} - result = request_approval( - action_type="luma_create_event", - tool_name="create_luma_event", - params={ - "name": name, - "start_at": start_at, - "end_at": end_at, - "description": description, - "timezone": timezone, - }, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Event was not created.", - } - - final_name = result.params.get("name", name) - final_start = result.params.get("start_at", start_at) - final_end = result.params.get("end_at", end_at) - final_desc = result.params.get("description", description) - final_tz = result.params.get("timezone", timezone) - - api_key = get_api_key(connector) - headers = luma_headers(api_key) - - body: dict[str, Any] = { - "name": final_name, - "start_at": final_start, - "end_at": final_end, - "timezone": final_tz, - } - if final_desc: - body["description_md"] = final_desc - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.post( - f"{LUMA_API}/event/create", - headers=headers, - json=body, + result = request_approval( + action_type="luma_create_event", + tool_name="create_luma_event", + params={ + "name": name, + "start_at": start_at, + "end_at": end_at, + "description": description, + "timezone": timezone, + }, + context={"connector_id": connector.id}, ) - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Luma Plus subscription required to create events via API.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Event was not created.", + } - data = resp.json() - event_id = data.get("api_id") or data.get("event", {}).get("api_id") + final_name = result.params.get("name", name) + final_start = result.params.get("start_at", start_at) + final_end = result.params.get("end_at", end_at) + final_desc = result.params.get("description", description) + final_tz = result.params.get("timezone", timezone) - return { - "status": "success", - "event_id": event_id, - "message": f"Event '{final_name}' created on Luma.", - } + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + body: dict[str, Any] = { + "name": final_name, + "start_at": final_start, + "end_at": final_end, + "timezone": final_tz, + } + if final_desc: + body["description_md"] = final_desc + + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.post( + f"{LUMA_API}/event/create", + headers=headers, + json=body, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Luma Plus subscription required to create events via API.", + } + if resp.status_code not in (200, 201): + return { + "status": "error", + "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}", + } + + data = resp.json() + event_id = data.get("api_id") or data.get("event", {}).get("api_id") + + return { + "status": "success", + "event_id": event_id, + "message": f"Event '{final_name}' created on Luma.", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py index aec5ad220..6885c2049 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_list_luma_events_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the list_luma_events tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured list_luma_events tool + """ + del db_session # per-call session — see docstring + @tool async def list_luma_events( max_results: int = 25, @@ -28,77 +47,80 @@ def create_list_luma_events_tool( Dictionary with status and a list of events including event_id, name, start_at, end_at, location, url. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Luma tool not properly configured."} max_results = min(max_results, 50) try: - connector = await get_luma_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Luma connector found."} + async with async_session_maker() as db_session: + connector = await get_luma_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Luma connector found."} - api_key = get_api_key(connector) - headers = luma_headers(api_key) + api_key = get_api_key(connector) + headers = luma_headers(api_key) - all_entries: list[dict] = [] - cursor = None + all_entries: list[dict] = [] + cursor = None - async with httpx.AsyncClient(timeout=20.0) as client: - while len(all_entries) < max_results: - params: dict[str, Any] = { - "limit": min(100, max_results - len(all_entries)) - } - if cursor: - params["cursor"] = cursor + async with httpx.AsyncClient(timeout=20.0) as client: + while len(all_entries) < max_results: + params: dict[str, Any] = { + "limit": min(100, max_results - len(all_entries)) + } + if cursor: + params["cursor"] = cursor - resp = await client.get( - f"{LUMA_API}/calendar/list-events", - headers=headers, - params=params, + resp = await client.get( + f"{LUMA_API}/calendar/list-events", + headers=headers, + params=params, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Luma API error: {resp.status_code}", + } + + data = resp.json() + entries = data.get("entries", []) + if not entries: + break + all_entries.extend(entries) + + next_cursor = data.get("next_cursor") + if not next_cursor: + break + cursor = next_cursor + + events = [] + for entry in all_entries[:max_results]: + ev = entry.get("event", {}) + geo = ev.get("geo_info", {}) + events.append( + { + "event_id": entry.get("api_id"), + "name": ev.get("name", "Untitled"), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location": geo.get("name", ""), + "url": ev.get("url", ""), + "visibility": ev.get("visibility", ""), + } ) - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Luma API error: {resp.status_code}", - } - - data = resp.json() - entries = data.get("entries", []) - if not entries: - break - all_entries.extend(entries) - - next_cursor = data.get("next_cursor") - if not next_cursor: - break - cursor = next_cursor - - events = [] - for entry in all_entries[:max_results]: - ev = entry.get("event", {}) - geo = ev.get("geo_info", {}) - events.append( - { - "event_id": entry.get("api_id"), - "name": ev.get("name", "Untitled"), - "start_at": ev.get("start_at", ""), - "end_at": ev.get("end_at", ""), - "timezone": ev.get("timezone", ""), - "location": geo.get("name", ""), - "url": ev.get("url", ""), - "visibility": ev.get("visibility", ""), - } - ) - - return {"status": "success", "events": events, "total": len(events)} + return {"status": "success", "events": events, "total": len(events)} except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py index b37a9d617..a8484e9c0 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_read_luma_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_luma_event tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_luma_event tool + """ + del db_session # per-call session — see docstring + @tool async def read_luma_event(event_id: str) -> dict[str, Any]: """Read detailed information about a specific Luma event. @@ -26,60 +45,63 @@ def create_read_luma_event_tool( Dictionary with status and full event details including description, attendees count, meeting URL. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Luma tool not properly configured."} try: - connector = await get_luma_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Luma connector found."} - - api_key = get_api_key(connector) - headers = luma_headers(api_key) - - async with httpx.AsyncClient(timeout=15.0) as client: - resp = await client.get( - f"{LUMA_API}/events/{event_id}", - headers=headers, + async with async_session_maker() as db_session: + connector = await get_luma_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Luma connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code == 404: - return { - "status": "not_found", - "message": f"Event '{event_id}' not found.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Luma API error: {resp.status_code}", + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + async with httpx.AsyncClient(timeout=15.0) as client: + resp = await client.get( + f"{LUMA_API}/events/{event_id}", + headers=headers, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } + if resp.status_code == 404: + return { + "status": "not_found", + "message": f"Event '{event_id}' not found.", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Luma API error: {resp.status_code}", + } + + data = resp.json() + ev = data.get("event", data) + geo = ev.get("geo_info", {}) + + event_detail = { + "event_id": event_id, + "name": ev.get("name", ""), + "description": ev.get("description", ""), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location_name": geo.get("name", ""), + "address": geo.get("address", ""), + "url": ev.get("url", ""), + "meeting_url": ev.get("meeting_url", ""), + "visibility": ev.get("visibility", ""), + "cover_url": ev.get("cover_url", ""), } - data = resp.json() - ev = data.get("event", data) - geo = ev.get("geo_info", {}) - - event_detail = { - "event_id": event_id, - "name": ev.get("name", ""), - "description": ev.get("description", ""), - "start_at": ev.get("start_at", ""), - "end_at": ev.get("end_at", ""), - "timezone": ev.get("timezone", ""), - "location_name": geo.get("name", ""), - "address": geo.get("address", ""), - "url": ev.get("url", ""), - "meeting_url": ev.get("meeting_url", ""), - "visibility": ev.get("visibility", ""), - "cover_url": ev.get("cover_url", ""), - } - - return {"status": "success", "event": event_detail} + return {"status": "success", "event": event_detail} except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py index 6efffe960..6ec95e9f0 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector +from app.db import async_session_maker from app.services.notion import NotionToolMetadataService logger = logging.getLogger(__name__) @@ -20,8 +21,17 @@ def create_create_notion_page_tool( """ Factory function to create the create_notion_page tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Per-call sessions also + keep the request's outer transaction free of long-running Notion API + blocking. + Args: - db_session: Database session for accessing Notion connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. search_space_id: Search space ID to find the Notion connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +39,7 @@ def create_create_notion_page_tool( Returns: Configured create_notion_page tool """ + del db_session # per-call session — see docstring @tool async def create_notion_page( @@ -67,7 +78,7 @@ def create_create_notion_page_tool( """ logger.info(f"create_notion_page called: title='{title}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Notion tool not properly configured - missing required parameters" ) @@ -77,154 +88,157 @@ def create_create_notion_page_tool( } try: - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return { - "status": "error", - "message": context["error"], - } - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Notion accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "notion", - } - - logger.info(f"Requesting approval for creating Notion page: '{title}'") - result = request_approval( - action_type="notion_page_creation", - tool_name="create_notion_page", - params={ - "title": title, - "content": content, - "parent_page_id": None, - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page creation rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_title = result.params.get("title", title) - final_content = result.params.get("content", content) - final_parent_page_id = result.params.get("parent_page_id") - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_title or not final_title.strip(): - logger.error("Title is empty or contains only whitespace") - return { - "status": "error", - "message": "Page title cannot be empty. Please provide a valid title.", - } - - logger.info( - f"Creating Notion page with final params: title='{final_title}'" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = NotionToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: - logger.warning( - f"No Notion connector found for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "No Notion connector found. Please connect Notion in your workspace settings.", - } - - actual_connector_id = connector.id - logger.info(f"Found Notion connector: id={actual_connector_id}") - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: + if "error" in context: logger.error( - f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}" + f"Failed to fetch creation context: {context['error']}" ) return { "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", + "message": context["error"], } - logger.info(f"Validated Notion connector: id={actual_connector_id}") - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + logger.warning("All Notion accounts have expired authentication") + return { + "status": "auth_error", + "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "notion", + } - result = await notion_connector.create_page( - title=final_title, - content=final_content, - parent_page_id=final_parent_page_id, - ) - logger.info( - f"create_page result: {result.get('status')} - {result.get('message', '')}" - ) + logger.info(f"Requesting approval for creating Notion page: '{title}'") + result = request_approval( + action_type="notion_page_creation", + tool_name="create_notion_page", + params={ + "title": title, + "content": content, + "parent_page_id": None, + "connector_id": connector_id, + }, + context=context, + ) - if result.get("status") == "success": - kb_message_suffix = "" - try: - from app.services.notion import NotionKBSyncService + if result.rejected: + logger.info("Notion page creation rejected by user") + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } - kb_service = NotionKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - page_id=result.get("page_id"), - page_title=result.get("title", final_title), - page_url=result.get("url"), - content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." + final_title = result.params.get("title", title) + final_content = result.params.get("content", content) + final_parent_page_id = result.params.get("parent_page_id") + final_connector_id = result.params.get("connector_id", connector_id) + + if not final_title or not final_title.strip(): + logger.error("Title is empty or contains only whitespace") + return { + "status": "error", + "message": "Page title cannot be empty. Please provide a valid title.", + } + + logger.info( + f"Creating Notion page with final params: title='{final_title}'" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, ) - else: + ) + connector = result.scalars().first() + + if not connector: + logger.warning( + f"No Notion connector found for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "No Notion connector found. Please connect Notion in your workspace settings.", + } + + actual_connector_id = connector.id + logger.info(f"Found Notion connector: id={actual_connector_id}") + else: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + logger.error( + f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", + } + logger.info(f"Validated Notion connector: id={actual_connector_id}") + + notion_connector = NotionHistoryConnector( + session=db_session, + connector_id=actual_connector_id, + ) + + result = await notion_connector.create_page( + title=final_title, + content=final_content, + parent_page_id=final_parent_page_id, + ) + logger.info( + f"create_page result: {result.get('status')} - {result.get('message', '')}" + ) + + if result.get("status") == "success": + kb_message_suffix = "" + try: + from app.services.notion import NotionKBSyncService + + kb_service = NotionKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + page_id=result.get("page_id"), + page_title=result.get("title", final_title), + page_url=result.get("url"), + content=final_content, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - result["message"] = result.get("message", "") + kb_message_suffix + result["message"] = result.get("message", "") + kb_message_suffix - return result + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py index 07f7583d2..7b85da4c2 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector +from app.db import async_session_maker from app.services.notion.tool_metadata_service import NotionToolMetadataService logger = logging.getLogger(__name__) @@ -20,8 +21,14 @@ def create_delete_notion_page_tool( """ Factory function to create the delete_notion_page tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + Args: - db_session: Database session for accessing Notion connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. search_space_id: Search space ID to find the Notion connector user_id: User ID for finding the correct Notion connector connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_delete_notion_page_tool( Returns: Configured delete_notion_page tool """ + del db_session # per-call session — see docstring @tool async def delete_notion_page( @@ -63,7 +71,7 @@ def create_delete_notion_page_tool( f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Notion tool not properly configured - missing required parameters" ) @@ -73,164 +81,167 @@ def create_delete_notion_page_tool( } try: - # Get page context (page_id, account, title) from indexed data - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_delete_context( - search_space_id, user_id, page_title - ) - - if "error" in context: - error_msg = context["error"] - # Check if it's a "not found" error (softer handling for LLM) - if "not found" in error_msg.lower(): - logger.warning(f"Page not found: {error_msg}") - return { - "status": "not_found", - "message": error_msg, - } - else: - logger.error(f"Failed to fetch delete context: {error_msg}") - return { - "status": "error", - "message": error_msg, - } - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Notion account %s has expired authentication", - account.get("id"), + async with async_session_maker() as db_session: + # Get page context (page_id, account, title) from indexed data + metadata_service = NotionToolMetadataService(db_session) + context = await metadata_service.get_delete_context( + search_space_id, user_id, page_title ) - return { - "status": "auth_error", - "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", - } - page_id = context.get("page_id") - connector_id_from_context = account.get("id") - document_id = context.get("document_id") - - logger.info( - f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})" - ) - - result = request_approval( - action_type="notion_page_deletion", - tool_name="delete_notion_page", - params={ - "page_id": page_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page deletion rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - logger.info( - f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - # Validate the connector - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", - } - actual_connector_id = connector.id - logger.info(f"Validated Notion connector: id={actual_connector_id}") - else: - logger.error("No connector found for this page") - return { - "status": "error", - "message": "No connector found for this page.", - } - - # Create connector instance - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) - - # Delete the page from Notion - result = await notion_connector.delete_page(page_id=final_page_id) - logger.info( - f"delete_page result: {result.get('status')} - {result.get('message', '')}" - ) - - # If deletion was successful and user wants to delete from KB - deleted_from_kb = False - if ( - result.get("status") == "success" - and final_delete_from_kb - and document_id - ): - try: - from sqlalchemy.future import select - - from app.db import Document - - # Get the document - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) + if "error" in context: + error_msg = context["error"] + # Check if it's a "not found" error (softer handling for LLM) + if "not found" in error_msg.lower(): + logger.warning(f"Page not found: {error_msg}") + return { + "status": "not_found", + "message": error_msg, + } else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - result["warning"] = ( - f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}" - ) + logger.error(f"Failed to fetch delete context: {error_msg}") + return { + "status": "error", + "message": error_msg, + } - # Update result with KB deletion status - if result.get("status") == "success": - result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - result["message"] = ( - f"{result.get('message', '')} (also removed from knowledge base)" + account = context.get("account", {}) + if account.get("auth_expired"): + logger.warning( + "Notion account %s has expired authentication", + account.get("id"), ) + return { + "status": "auth_error", + "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", + } - return result + page_id = context.get("page_id") + connector_id_from_context = account.get("id") + document_id = context.get("document_id") + + logger.info( + f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})" + ) + + result = request_approval( + action_type="notion_page_deletion", + tool_name="delete_notion_page", + params={ + "page_id": page_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, + ) + + if result.rejected: + logger.info("Notion page deletion rejected by user") + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_page_id = result.params.get("page_id", page_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + logger.info( + f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + # Validate the connector + if final_connector_id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", + } + actual_connector_id = connector.id + logger.info(f"Validated Notion connector: id={actual_connector_id}") + else: + logger.error("No connector found for this page") + return { + "status": "error", + "message": "No connector found for this page.", + } + + # Create connector instance + notion_connector = NotionHistoryConnector( + session=db_session, + connector_id=actual_connector_id, + ) + + # Delete the page from Notion + result = await notion_connector.delete_page(page_id=final_page_id) + logger.info( + f"delete_page result: {result.get('status')} - {result.get('message', '')}" + ) + + # If deletion was successful and user wants to delete from KB + deleted_from_kb = False + if ( + result.get("status") == "success" + and final_delete_from_kb + and document_id + ): + try: + from sqlalchemy.future import select + + from app.db import Document + + # Get the document + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + result["warning"] = ( + f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}" + ) + + # Update result with KB deletion status + if result.get("status") == "success": + result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + result["message"] = ( + f"{result.get('message', '')} (also removed from knowledge base)" + ) + + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py index 85c08177c..df757476a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector +from app.db import async_session_maker from app.services.notion import NotionToolMetadataService logger = logging.getLogger(__name__) @@ -20,8 +21,14 @@ def create_update_notion_page_tool( """ Factory function to create the update_notion_page tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache (see + ``create_create_notion_page_tool`` for the full rationale). + Args: - db_session: Database session for accessing Notion connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. search_space_id: Search space ID to find the Notion connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_update_notion_page_tool( Returns: Configured update_notion_page tool """ + del db_session # per-call session — see docstring @tool async def update_notion_page( @@ -71,7 +79,7 @@ def create_update_notion_page_tool( f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Notion tool not properly configured - missing required parameters" ) @@ -88,152 +96,155 @@ def create_update_notion_page_tool( } try: - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, page_title - ) - - if "error" in context: - error_msg = context["error"] - # Check if it's a "not found" error (softer handling for LLM) - if "not found" in error_msg.lower(): - logger.warning(f"Page not found: {error_msg}") - return { - "status": "not_found", - "message": error_msg, - } - else: - logger.error(f"Failed to fetch update context: {error_msg}") - return { - "status": "error", - "message": error_msg, - } - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Notion account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", - } - - page_id = context.get("page_id") - document_id = context.get("document_id") - connector_id_from_context = context.get("account", {}).get("id") - - logger.info( - f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})" - ) - result = request_approval( - action_type="notion_page_update", - tool_name="update_notion_page", - params={ - "page_id": page_id, - "content": content, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page update rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_content = result.params.get("content", content) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - - logger.info( - f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", - } - actual_connector_id = connector.id - logger.info(f"Validated Notion connector: id={actual_connector_id}") - else: - logger.error("No connector found for this page") - return { - "status": "error", - "message": "No connector found for this page.", - } - - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) - - result = await notion_connector.update_page( - page_id=final_page_id, - content=final_content, - ) - logger.info( - f"update_page result: {result.get('status')} - {result.get('message', '')}" - ) - - if result.get("status") == "success" and document_id is not None: - from app.services.notion import NotionKBSyncService - - logger.info(f"Updating knowledge base for document {document_id}...") - kb_service = NotionKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=document_id, - appended_content=final_content, - user_id=user_id, - search_space_id=search_space_id, - appended_block_ids=result.get("appended_block_ids"), + async with async_session_maker() as db_session: + metadata_service = NotionToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, page_title ) - if kb_result["status"] == "success": - result["message"] = ( - f"{result['message']}. Your knowledge base has also been updated." - ) - logger.info( - f"Knowledge base successfully updated for page {final_page_id}" - ) - elif kb_result["status"] == "not_indexed": - result["message"] = ( - f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync." - ) - else: - result["message"] = ( - f"{result['message']}. Your knowledge base will be updated in the next scheduled sync." - ) + if "error" in context: + error_msg = context["error"] + # Check if it's a "not found" error (softer handling for LLM) + if "not found" in error_msg.lower(): + logger.warning(f"Page not found: {error_msg}") + return { + "status": "not_found", + "message": error_msg, + } + else: + logger.error(f"Failed to fetch update context: {error_msg}") + return { + "status": "error", + "message": error_msg, + } + + account = context.get("account", {}) + if account.get("auth_expired"): logger.warning( - f"KB update failed for page {final_page_id}: {kb_result['message']}" + "Notion account %s has expired authentication", + account.get("id"), + ) + return { + "status": "auth_error", + "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", + } + + page_id = context.get("page_id") + document_id = context.get("document_id") + connector_id_from_context = context.get("account", {}).get("id") + + logger.info( + f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})" + ) + result = request_approval( + action_type="notion_page_update", + tool_name="update_notion_page", + params={ + "page_id": page_id, + "content": content, + "connector_id": connector_id_from_context, + }, + context=context, + ) + + if result.rejected: + logger.info("Notion page update rejected by user") + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_page_id = result.params.get("page_id", page_id) + final_content = result.params.get("content", content) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + + logger.info( + f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if final_connector_id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", + } + actual_connector_id = connector.id + logger.info(f"Validated Notion connector: id={actual_connector_id}") + else: + logger.error("No connector found for this page") + return { + "status": "error", + "message": "No connector found for this page.", + } + + notion_connector = NotionHistoryConnector( + session=db_session, + connector_id=actual_connector_id, + ) + + result = await notion_connector.update_page( + page_id=final_page_id, + content=final_content, + ) + logger.info( + f"update_page result: {result.get('status')} - {result.get('message', '')}" + ) + + if result.get("status") == "success" and document_id is not None: + from app.services.notion import NotionKBSyncService + + logger.info( + f"Updating knowledge base for document {document_id}..." + ) + kb_service = NotionKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=document_id, + appended_content=final_content, + user_id=user_id, + search_space_id=search_space_id, + appended_block_ids=result.get("appended_block_ids"), ) - return result + if kb_result["status"] == "success": + result["message"] = ( + f"{result['message']}. Your knowledge base has also been updated." + ) + logger.info( + f"Knowledge base successfully updated for page {final_page_id}" + ) + elif kb_result["status"] == "not_indexed": + result["message"] = ( + f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync." + ) + else: + result["message"] = ( + f"{result['message']}. Your knowledge base will be updated in the next scheduled sync." + ) + logger.warning( + f"KB update failed for page {final_page_id}: {kb_result['message']}" + ) + + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py index 21272e01d..5f199a41b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py @@ -10,7 +10,7 @@ from sqlalchemy.future import select from app.agents.new_chat.tools.hitl import request_approval from app.connectors.onedrive.client import OneDriveClient -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -48,6 +48,23 @@ def create_create_onedrive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_onedrive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_onedrive_file tool + """ + del db_session # per-call session — see docstring + @tool async def create_onedrive_file( name: str, @@ -70,173 +87,178 @@ def create_create_onedrive_file_tool( """ logger.info(f"create_onedrive_file called: name='{name}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "OneDrive tool not properly configured.", } try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, - ) - ) - connectors = result.scalars().all() - - if not connectors: - return { - "status": "error", - "message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.", - } - - accounts = [] - for c in connectors: - cfg = c.config or {} - accounts.append( - { - "id": c.id, - "name": c.name, - "user_email": cfg.get("user_email"), - "auth_expired": cfg.get("auth_expired", False), - } - ) - - if all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected OneDrive accounts need re-authentication.", - "connector_type": "onedrive", - } - - parent_folders: dict[int, list[dict[str, str]]] = {} - for acc in accounts: - cid = acc["id"] - if acc.get("auth_expired"): - parent_folders[cid] = [] - continue - try: - client = OneDriveClient(session=db_session, connector_id=cid) - items, err = await client.list_children("root") - if err: - logger.warning( - "Failed to list folders for connector %s: %s", cid, err - ) - parent_folders[cid] = [] - else: - parent_folders[cid] = [ - {"folder_id": item["id"], "name": item["name"]} - for item in items - if item.get("folder") is not None - and item.get("id") - and item.get("name") - ] - except Exception: - logger.warning( - "Error fetching folders for connector %s", cid, exc_info=True - ) - parent_folders[cid] = [] - - context: dict[str, Any] = { - "accounts": accounts, - "parent_folders": parent_folders, - } - - result = request_approval( - action_type="onedrive_file_creation", - tool_name="create_onedrive_file", - params={ - "name": name, - "content": content, - "connector_id": None, - "parent_folder_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_id = result.params.get("parent_folder_id") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - final_name = _ensure_docx_extension(final_name) - - if final_connector_id is not None: + async with async_session_maker() as db_session: result = await db_session.execute( select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, ) ) - connector = result.scalars().first() - else: - connector = connectors[0] + connectors = result.scalars().all() - if not connector: - return { - "status": "error", - "message": "Selected OneDrive connector is invalid.", + if not connectors: + return { + "status": "error", + "message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.", + } + + accounts = [] + for c in connectors: + cfg = c.config or {} + accounts.append( + { + "id": c.id, + "name": c.name, + "user_email": cfg.get("user_email"), + "auth_expired": cfg.get("auth_expired", False), + } + ) + + if all(a.get("auth_expired") for a in accounts): + return { + "status": "auth_error", + "message": "All connected OneDrive accounts need re-authentication.", + "connector_type": "onedrive", + } + + parent_folders: dict[int, list[dict[str, str]]] = {} + for acc in accounts: + cid = acc["id"] + if acc.get("auth_expired"): + parent_folders[cid] = [] + continue + try: + client = OneDriveClient(session=db_session, connector_id=cid) + items, err = await client.list_children("root") + if err: + logger.warning( + "Failed to list folders for connector %s: %s", cid, err + ) + parent_folders[cid] = [] + else: + parent_folders[cid] = [ + {"folder_id": item["id"], "name": item["name"]} + for item in items + if item.get("folder") is not None + and item.get("id") + and item.get("name") + ] + except Exception: + logger.warning( + "Error fetching folders for connector %s", + cid, + exc_info=True, + ) + parent_folders[cid] = [] + + context: dict[str, Any] = { + "accounts": accounts, + "parent_folders": parent_folders, } - docx_bytes = _markdown_to_docx(final_content or "") - - client = OneDriveClient(session=db_session, connector_id=connector.id) - created = await client.create_file( - name=final_name, - parent_id=final_parent_folder_id, - content=docx_bytes, - mime_type=DOCX_MIME, - ) - - logger.info( - f"OneDrive file created: id={created.get('id')}, name={created.get('name')}" - ) - - kb_message_suffix = "" - try: - from app.services.onedrive import OneDriveKBSyncService - - kb_service = OneDriveKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=created.get("id"), - file_name=created.get("name", final_name), - mime_type=DOCX_MIME, - web_url=created.get("webUrl"), - content=final_content, - connector_id=connector.id, - search_space_id=search_space_id, - user_id=user_id, + result = request_approval( + action_type="onedrive_file_creation", + tool_name="create_onedrive_file", + params={ + "name": name, + "content": content, + "connector_id": None, + "parent_folder_id": None, + }, + context=context, ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - return { - "status": "success", - "file_id": created.get("id"), - "name": created.get("name"), - "web_url": created.get("webUrl"), - "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_name = result.params.get("name", name) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_id = result.params.get("parent_folder_id") + + if not final_name or not final_name.strip(): + return {"status": "error", "message": "File name cannot be empty."} + + final_name = _ensure_docx_extension(final_name) + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + ) + ) + connector = result.scalars().first() + else: + connector = connectors[0] + + if not connector: + return { + "status": "error", + "message": "Selected OneDrive connector is invalid.", + } + + docx_bytes = _markdown_to_docx(final_content or "") + + client = OneDriveClient(session=db_session, connector_id=connector.id) + created = await client.create_file( + name=final_name, + parent_id=final_parent_folder_id, + content=docx_bytes, + mime_type=DOCX_MIME, + ) + + logger.info( + f"OneDrive file created: id={created.get('id')}, name={created.get('name')}" + ) + + kb_message_suffix = "" + try: + from app.services.onedrive import OneDriveKBSyncService + + kb_service = OneDriveKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + file_id=created.get("id"), + file_name=created.get("name", final_name), + mime_type=DOCX_MIME, + web_url=created.get("webUrl"), + content=final_content, + connector_id=connector.id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "file_id": created.get("id"), + "name": created.get("name"), + "web_url": created.get("webUrl"), + "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py index a7f13b5df..4857ea988 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py @@ -13,6 +13,7 @@ from app.db import ( DocumentType, SearchSourceConnector, SearchSourceConnectorType, + async_session_maker, ) logger = logging.getLogger(__name__) @@ -23,6 +24,23 @@ def create_delete_onedrive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the delete_onedrive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured delete_onedrive_file tool + """ + del db_session # per-call session — see docstring + @tool async def delete_onedrive_file( file_name: str, @@ -56,33 +74,14 @@ def create_delete_onedrive_file_tool( f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "OneDrive tool not properly configured.", } try: - doc_result = await db_session.execute( - select(Document) - .join( - SearchSourceConnector, - Document.connector_id == SearchSourceConnector.id, - ) - .filter( - and_( - Document.search_space_id == search_space_id, - Document.document_type == DocumentType.ONEDRIVE_FILE, - func.lower(Document.title) == func.lower(file_name), - SearchSourceConnector.user_id == user_id, - ) - ) - .order_by(Document.updated_at.desc().nullslast()) - .limit(1) - ) - document = doc_result.scalars().first() - - if not document: + async with async_session_maker() as db_session: doc_result = await db_session.execute( select(Document) .join( @@ -93,13 +92,7 @@ def create_delete_onedrive_file_tool( and_( Document.search_space_id == search_space_id, Document.document_type == DocumentType.ONEDRIVE_FILE, - func.lower( - cast( - Document.document_metadata["onedrive_file_name"], - String, - ) - ) - == func.lower(file_name), + func.lower(Document.title) == func.lower(file_name), SearchSourceConnector.user_id == user_id, ) ) @@ -108,98 +101,64 @@ def create_delete_onedrive_file_tool( ) document = doc_result.scalars().first() - if not document: - return { - "status": "not_found", - "message": ( - f"File '{file_name}' not found in your indexed OneDrive files. " - "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " - "or (3) the file name is different." - ), - } - - if not document.connector_id: - return { - "status": "error", - "message": "Document has no associated connector.", - } - - meta = document.document_metadata or {} - file_id = meta.get("onedrive_file_id") - document_id = document.id - - if not file_id: - return { - "status": "error", - "message": "File ID is missing. Please re-index the file.", - } - - conn_result = await db_session.execute( - select(SearchSourceConnector).filter( - and_( - SearchSourceConnector.id == document.connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + if not document: + doc_result = await db_session.execute( + select(Document) + .join( + SearchSourceConnector, + Document.connector_id == SearchSourceConnector.id, + ) + .filter( + and_( + Document.search_space_id == search_space_id, + Document.document_type == DocumentType.ONEDRIVE_FILE, + func.lower( + cast( + Document.document_metadata[ + "onedrive_file_name" + ], + String, + ) + ) + == func.lower(file_name), + SearchSourceConnector.user_id == user_id, + ) + ) + .order_by(Document.updated_at.desc().nullslast()) + .limit(1) ) - ) - ) - connector = conn_result.scalars().first() - if not connector: - return { - "status": "error", - "message": "OneDrive connector not found or access denied.", - } + document = doc_result.scalars().first() - cfg = connector.config or {} - if cfg.get("auth_expired"): - return { - "status": "auth_error", - "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "onedrive", - } + if not document: + return { + "status": "not_found", + "message": ( + f"File '{file_name}' not found in your indexed OneDrive files. " + "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " + "or (3) the file name is different." + ), + } - context = { - "file": { - "file_id": file_id, - "name": file_name, - "document_id": document_id, - "web_url": meta.get("web_url"), - }, - "account": { - "id": connector.id, - "name": connector.name, - "user_email": cfg.get("user_email"), - }, - } + if not document.connector_id: + return { + "status": "error", + "message": "Document has no associated connector.", + } - result = request_approval( - action_type="onedrive_file_trash", - tool_name="delete_onedrive_file", - params={ - "file_id": file_id, - "connector_id": connector.id, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + meta = document.document_metadata or {} + file_id = meta.get("onedrive_file_id") + document_id = document.id - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } + if not file_id: + return { + "status": "error", + "message": "File ID is missing. Please re-index the file.", + } - final_file_id = result.params.get("file_id", file_id) - final_connector_id = result.params.get("connector_id", connector.id) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if final_connector_id != connector.id: - result = await db_session.execute( + conn_result = await db_session.execute( select(SearchSourceConnector).filter( and_( - SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.id == document.connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type @@ -207,65 +166,130 @@ def create_delete_onedrive_file_tool( ) ) ) - validated_connector = result.scalars().first() - if not validated_connector: + connector = conn_result.scalars().first() + if not connector: return { "status": "error", - "message": "Selected OneDrive connector is invalid or has been disconnected.", + "message": "OneDrive connector not found or access denied.", } - actual_connector_id = validated_connector.id - else: - actual_connector_id = connector.id - logger.info( - f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}" - ) + cfg = connector.config or {} + if cfg.get("auth_expired"): + return { + "status": "auth_error", + "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "onedrive", + } - client = OneDriveClient( - session=db_session, connector_id=actual_connector_id - ) - await client.trash_file(final_file_id) + context = { + "file": { + "file_id": file_id, + "name": file_name, + "document_id": document_id, + "web_url": meta.get("web_url"), + }, + "account": { + "id": connector.id, + "name": connector.name, + "user_email": cfg.get("user_email"), + }, + } - logger.info( - f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}" - ) - - trash_result: dict[str, Any] = { - "status": "success", - "file_id": final_file_id, - "message": f"Successfully moved '{file_name}' to the recycle bin.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - doc = doc_result.scalars().first() - if doc: - await db_session.delete(doc) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + result = request_approval( + action_type="onedrive_file_trash", + tool_name="delete_onedrive_file", + params={ + "file_id": file_id, + "connector_id": connector.id, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - return trash_result + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_file_id = result.params.get("file_id", file_id) + final_connector_id = result.params.get("connector_id", connector.id) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + if final_connector_id != connector.id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + and_( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id + == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + ) + ) + ) + validated_connector = result.scalars().first() + if not validated_connector: + return { + "status": "error", + "message": "Selected OneDrive connector is invalid or has been disconnected.", + } + actual_connector_id = validated_connector.id + else: + actual_connector_id = connector.id + + logger.info( + f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}" + ) + + client = OneDriveClient( + session=db_session, connector_id=actual_connector_id + ) + await client.trash_file(final_file_id) + + logger.info( + f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}" + ) + + trash_result: dict[str, Any] = { + "status": "success", + "file_id": final_file_id, + "message": f"Successfully moved '{file_name}' to the recycle bin.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + doc = doc_result.scalars().first() + if doc: + await db_session.delete(doc) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index e8bab36fd..b842d7a20 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -824,13 +824,22 @@ async def build_tools_async( """Async version of build_tools that also loads MCP tools from database. Design Note: - This function exists because MCP tools require database queries to load user configs, - while built-in tools are created synchronously from static code. + This function exists because MCP tools require database queries to load + user configs, while built-in tools are created synchronously from static + code. - Alternative: We could make build_tools() itself async and always query the database, - but that would force async everywhere even when only using built-in tools. The current - design keeps the simple case (static tools only) synchronous while supporting dynamic - database-loaded tools through this async wrapper. + Alternative: We could make build_tools() itself async and always query + the database, but that would force async everywhere even when only using + built-in tools. The current design keeps the simple case (static tools + only) synchronous while supporting dynamic database-loaded tools through + this async wrapper. + + Phase 1.3: built-in tool construction (CPU; runs in a thread pool to + avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on + the event loop) are kicked off concurrently. Cold-path savings are + bounded by the slower of the two — typically MCP at ~200ms-1.7s — + so the parallelization recovers the ~50-200ms previously spent + serially on built-in construction. Args: dependencies: Dict containing all possible dependencies @@ -843,33 +852,70 @@ async def build_tools_async( List of configured tool instances ready for the agent, including MCP tools. """ + import asyncio import time _perf_log = logging.getLogger("surfsense.perf") _perf_log.setLevel(logging.DEBUG) + can_load_mcp = ( + include_mcp_tools + and "db_session" in dependencies + and "search_space_id" in dependencies + ) + + # Built-in tool construction is synchronous + CPU-only. Off-loop it so + # MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is pure + # function over its inputs — safe to thread-shift. _t0 = time.perf_counter() - tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools) + builtin_task = asyncio.create_task( + asyncio.to_thread( + build_tools, dependencies, enabled_tools, disabled_tools, additional_tools + ) + ) + + mcp_task: asyncio.Task | None = None + if can_load_mcp: + mcp_task = asyncio.create_task( + load_mcp_tools( + dependencies["db_session"], + dependencies["search_space_id"], + ) + ) + + # Surface failures from each task independently so a flaky MCP + # endpoint never poisons built-in tool registration. ``return_exceptions`` + # gives us per-task exceptions instead of dropping the second result + # when the first raises. + if mcp_task is not None: + builtin_result, mcp_result = await asyncio.gather( + builtin_task, mcp_task, return_exceptions=True + ) + else: + builtin_result = await builtin_task + mcp_result = None + + if isinstance(builtin_result, BaseException): + raise builtin_result # built-in registration failure is non-recoverable + tools: list[BaseTool] = builtin_result _perf_log.info( - "[build_tools_async] Built-in tools in %.3fs (%d tools)", + "[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)", time.perf_counter() - _t0, len(tools), ) - # Load MCP tools if requested and dependencies are available - if ( - include_mcp_tools - and "db_session" in dependencies - and "search_space_id" in dependencies - ): - try: - _t0 = time.perf_counter() - mcp_tools = await load_mcp_tools( - dependencies["db_session"], - dependencies["search_space_id"], + if mcp_task is not None: + if isinstance(mcp_result, BaseException): + # ``return_exceptions=True`` captures the exception out-of-band, + # so ``sys.exc_info()`` is empty here. Pass the captured + # exception via ``exc_info=`` to get a real traceback. + logging.error( + "Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result ) + else: + mcp_tools = mcp_result or [] _perf_log.info( - "[build_tools_async] MCP tools loaded in %.3fs (%d tools)", + "[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)", time.perf_counter() - _t0, len(mcp_tools), ) @@ -879,8 +925,6 @@ async def build_tools_async( len(mcp_tools), [t.name for t in mcp_tools], ) - except Exception as e: - logging.exception("Failed to load MCP tools: %s", e) logging.info( "Total tools for agent: %d — %s", diff --git a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py b/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py index b8b1527c7..2965f2f02 100644 --- a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py +++ b/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py @@ -15,7 +15,7 @@ from langchain_core.tools import tool from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument +from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker from app.utils.document_converters import embed_text @@ -124,12 +124,19 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession): """ Factory function to create the search_surfsense_docs tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + Args: - db_session: Database session for executing queries + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. Returns: A configured tool function for searching Surfsense documentation """ + del db_session # per-call session — see docstring @tool async def search_surfsense_docs(query: str, top_k: int = 10) -> str: @@ -155,10 +162,11 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession): Returns: Relevant documentation content formatted with chunk IDs for citations """ - return await search_surfsense_docs_async( - query=query, - db_session=db_session, - top_k=top_k, - ) + async with async_session_maker() as db_session: + return await search_surfsense_docs_async( + query=query, + db_session=db_session, + top_k=top_k, + ) return search_surfsense_docs diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py index d7b000853..0fc52b5c7 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import GRAPH_API, get_access_token, get_teams_connector logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_list_teams_channels_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the list_teams_channels tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured list_teams_channels tool + """ + del db_session # per-call session — see docstring + @tool async def list_teams_channels() -> dict[str, Any]: """List all Microsoft Teams and their channels the user has access to. @@ -23,63 +42,66 @@ def create_list_teams_channels_tool( Dictionary with status and a list of teams, each containing team_id, team_name, and a list of channels (id, name). """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Teams tool not properly configured."} try: - connector = await get_teams_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Teams connector found."} - - token = await get_access_token(db_session, connector) - headers = {"Authorization": f"Bearer {token}"} - - async with httpx.AsyncClient(timeout=20.0) as client: - teams_resp = await client.get( - f"{GRAPH_API}/me/joinedTeams", headers=headers + async with async_session_maker() as db_session: + connector = await get_teams_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Teams connector found."} - if teams_resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. Please re-authenticate.", - "connector_type": "teams", - } - if teams_resp.status_code != 200: - return { - "status": "error", - "message": f"Graph API error: {teams_resp.status_code}", - } + token = await get_access_token(db_session, connector) + headers = {"Authorization": f"Bearer {token}"} - teams_data = teams_resp.json().get("value", []) - result_teams = [] - - async with httpx.AsyncClient(timeout=20.0) as client: - for team in teams_data: - team_id = team["id"] - ch_resp = await client.get( - f"{GRAPH_API}/teams/{team_id}/channels", - headers=headers, - ) - channels = [] - if ch_resp.status_code == 200: - channels = [ - {"id": ch["id"], "name": ch.get("displayName", "")} - for ch in ch_resp.json().get("value", []) - ] - result_teams.append( - { - "team_id": team_id, - "team_name": team.get("displayName", ""), - "channels": channels, - } + async with httpx.AsyncClient(timeout=20.0) as client: + teams_resp = await client.get( + f"{GRAPH_API}/me/joinedTeams", headers=headers ) - return { - "status": "success", - "teams": result_teams, - "total_teams": len(result_teams), - } + if teams_resp.status_code == 401: + return { + "status": "auth_error", + "message": "Teams token expired. Please re-authenticate.", + "connector_type": "teams", + } + if teams_resp.status_code != 200: + return { + "status": "error", + "message": f"Graph API error: {teams_resp.status_code}", + } + + teams_data = teams_resp.json().get("value", []) + result_teams = [] + + async with httpx.AsyncClient(timeout=20.0) as client: + for team in teams_data: + team_id = team["id"] + ch_resp = await client.get( + f"{GRAPH_API}/teams/{team_id}/channels", + headers=headers, + ) + channels = [] + if ch_resp.status_code == 200: + channels = [ + {"id": ch["id"], "name": ch.get("displayName", "")} + for ch in ch_resp.json().get("value", []) + ] + result_teams.append( + { + "team_id": team_id, + "team_name": team.get("displayName", ""), + "channels": channels, + } + ) + + return { + "status": "success", + "teams": result_teams, + "total_teams": len(result_teams), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py index d24a7e4d3..0ebda021e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import GRAPH_API, get_access_token, get_teams_connector logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_read_teams_messages_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_teams_messages tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_teams_messages tool + """ + del db_session # per-call session — see docstring + @tool async def read_teams_messages( team_id: str, @@ -32,65 +51,68 @@ def create_read_teams_messages_tool( Dictionary with status and a list of messages including id, sender, content, timestamp. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Teams tool not properly configured."} limit = min(limit, 50) try: - connector = await get_teams_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Teams connector found."} - - token = await get_access_token(db_session, connector) - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.get( - f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages", - headers={"Authorization": f"Bearer {token}"}, - params={"$top": limit}, + async with async_session_maker() as db_session: + connector = await get_teams_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Teams connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. Please re-authenticate.", - "connector_type": "teams", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Insufficient permissions to read this channel.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Graph API error: {resp.status_code}", - } + token = await get_access_token(db_session, connector) - raw_msgs = resp.json().get("value", []) - messages = [] - for m in raw_msgs: - sender = m.get("from", {}) - user_info = sender.get("user", {}) if sender else {} - body = m.get("body", {}) - messages.append( - { - "id": m.get("id"), - "sender": user_info.get("displayName", "Unknown"), - "content": body.get("content", ""), - "content_type": body.get("contentType", "text"), - "timestamp": m.get("createdDateTime", ""), + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.get( + f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages", + headers={"Authorization": f"Bearer {token}"}, + params={"$top": limit}, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Teams token expired. Please re-authenticate.", + "connector_type": "teams", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Insufficient permissions to read this channel.", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Graph API error: {resp.status_code}", } - ) - return { - "status": "success", - "team_id": team_id, - "channel_id": channel_id, - "messages": messages, - "total": len(messages), - } + raw_msgs = resp.json().get("value", []) + messages = [] + for m in raw_msgs: + sender = m.get("from", {}) + user_info = sender.get("user", {}) if sender else {} + body = m.get("body", {}) + messages.append( + { + "id": m.get("id"), + "sender": user_info.get("displayName", "Unknown"), + "content": body.get("content", ""), + "content_type": body.get("contentType", "text"), + "timestamp": m.get("createdDateTime", ""), + } + ) + + return { + "status": "success", + "team_id": team_id, + "channel_id": channel_id, + "messages": messages, + "total": len(messages), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py index fd8d00870..6f40d27e1 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py @@ -6,6 +6,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from ._auth import GRAPH_API, get_access_token, get_teams_connector @@ -17,6 +18,23 @@ def create_send_teams_message_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the send_teams_message tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured send_teams_message tool + """ + del db_session # per-call session — see docstring + @tool async def send_teams_message( team_id: str, @@ -39,70 +57,73 @@ def create_send_teams_message_tool( IMPORTANT: - If status is "rejected", the user explicitly declined. Do NOT retry. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Teams tool not properly configured."} try: - connector = await get_teams_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Teams connector found."} + async with async_session_maker() as db_session: + connector = await get_teams_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Teams connector found."} - result = request_approval( - action_type="teams_send_message", - tool_name="send_teams_message", - params={ - "team_id": team_id, - "channel_id": channel_id, - "content": content, - }, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Message was not sent.", - } - - final_content = result.params.get("content", content) - final_team = result.params.get("team_id", team_id) - final_channel = result.params.get("channel_id", channel_id) - - token = await get_access_token(db_session, connector) - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.post( - f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages", - headers={ - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", + result = request_approval( + action_type="teams_send_message", + tool_name="send_teams_message", + params={ + "team_id": team_id, + "channel_id": channel_id, + "content": content, }, - json={"body": {"content": final_content}}, + context={"connector_id": connector.id}, ) - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. Please re-authenticate.", - "connector_type": "teams", - } - if resp.status_code == 403: - return { - "status": "insufficient_permissions", - "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Message was not sent.", + } - msg_data = resp.json() - return { - "status": "success", - "message_id": msg_data.get("id"), - "message": "Message sent to Teams channel.", - } + final_content = result.params.get("content", content) + final_team = result.params.get("team_id", team_id) + final_channel = result.params.get("channel_id", channel_id) + + token = await get_access_token(db_session, connector) + + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.post( + f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + json={"body": {"content": final_content}}, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Teams token expired. Please re-authenticate.", + "connector_type": "teams", + } + if resp.status_code == 403: + return { + "status": "insufficient_permissions", + "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.", + } + if resp.status_code not in (200, 201): + return { + "status": "error", + "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}", + } + + msg_data = resp.json() + return { + "status": "success", + "message_id": msg_data.get("id"), + "message": "Message sent to Teams channel.", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/update_memory.py b/surfsense_backend/app/agents/new_chat/tools/update_memory.py index ceaddb80f..42148967c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/update_memory.py +++ b/surfsense_backend/app/agents/new_chat/tools/update_memory.py @@ -26,7 +26,7 @@ from langchain_core.tools import tool from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.db import SearchSpace, User +from app.db import SearchSpace, User, async_session_maker from app.utils.content_utils import extract_text_content logger = logging.getLogger(__name__) @@ -302,6 +302,25 @@ def create_update_memory_tool( db_session: AsyncSession, llm: Any | None = None, ): + """Factory function to create the user-memory update tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + The session's bound ``commit``/``rollback`` methods are captured at + call time, after ``async with`` has bound ``db_session`` locally. + + Args: + user_id: ID of the user whose memory document is being updated. + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + llm: Optional LLM for the forced-rewrite path. + + Returns: + Configured update_memory tool for the user-memory scope. + """ + del db_session # per-call session — see docstring uid = UUID(user_id) if isinstance(user_id, str) else user_id @tool @@ -318,26 +337,26 @@ def create_update_memory_tool( updated_memory: The FULL updated markdown document (not a diff). """ try: - result = await db_session.execute(select(User).where(User.id == uid)) - user = result.scalars().first() - if not user: - return {"status": "error", "message": "User not found."} + async with async_session_maker() as db_session: + result = await db_session.execute(select(User).where(User.id == uid)) + user = result.scalars().first() + if not user: + return {"status": "error", "message": "User not found."} - old_memory = user.memory_md + old_memory = user.memory_md - return await _save_memory( - updated_memory=updated_memory, - old_memory=old_memory, - llm=llm, - apply_fn=lambda content: setattr(user, "memory_md", content), - commit_fn=db_session.commit, - rollback_fn=db_session.rollback, - label="memory", - scope="user", - ) + return await _save_memory( + updated_memory=updated_memory, + old_memory=old_memory, + llm=llm, + apply_fn=lambda content: setattr(user, "memory_md", content), + commit_fn=db_session.commit, + rollback_fn=db_session.rollback, + label="memory", + scope="user", + ) except Exception as e: logger.exception("Failed to update user memory: %s", e) - await db_session.rollback() return { "status": "error", "message": f"Failed to update memory: {e}", @@ -351,6 +370,27 @@ def create_update_team_memory_tool( db_session: AsyncSession, llm: Any | None = None, ): + """Factory function to create the team-memory update tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + The session's bound ``commit``/``rollback`` methods are captured at + call time, after ``async with`` has bound ``db_session`` locally. + + Args: + search_space_id: ID of the search space whose team memory is being + updated. + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + llm: Optional LLM for the forced-rewrite path. + + Returns: + Configured update_memory tool for the team-memory scope. + """ + del db_session # per-call session — see docstring + @tool async def update_memory(updated_memory: str) -> dict[str, Any]: """Update the team's shared memory document for this search space. @@ -366,28 +406,30 @@ def create_update_team_memory_tool( updated_memory: The FULL updated markdown document (not a diff). """ try: - result = await db_session.execute( - select(SearchSpace).where(SearchSpace.id == search_space_id) - ) - space = result.scalars().first() - if not space: - return {"status": "error", "message": "Search space not found."} + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSpace).where(SearchSpace.id == search_space_id) + ) + space = result.scalars().first() + if not space: + return {"status": "error", "message": "Search space not found."} - old_memory = space.shared_memory_md + old_memory = space.shared_memory_md - return await _save_memory( - updated_memory=updated_memory, - old_memory=old_memory, - llm=llm, - apply_fn=lambda content: setattr(space, "shared_memory_md", content), - commit_fn=db_session.commit, - rollback_fn=db_session.rollback, - label="team memory", - scope="team", - ) + return await _save_memory( + updated_memory=updated_memory, + old_memory=old_memory, + llm=llm, + apply_fn=lambda content: setattr( + space, "shared_memory_md", content + ), + commit_fn=db_session.commit, + rollback_fn=db_session.rollback, + label="team memory", + scope="team", + ) except Exception as e: logger.exception("Failed to update team memory: %s", e) - await db_session.rollback() return { "status": "error", "message": f"Failed to update team memory: {e}", diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 016c2de42..08194e7fb 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -31,6 +31,7 @@ from app.config import ( initialize_image_gen_router, initialize_llm_router, initialize_openrouter_integration, + initialize_pricing_registration, initialize_vision_llm_router, ) from app.db import User, create_db_and_tables, get_async_session @@ -420,6 +421,135 @@ def _stop_openrouter_background_refresh() -> None: OpenRouterIntegrationService.get_instance().stop_background_refresh() +async def _warm_agent_jit_caches() -> None: + """Pay the LangChain / LangGraph / Deepagents JIT cost at startup. + + Why + ---- + A cold ``create_agent`` + ``StateGraph.compile()`` + Pydantic schema + generation chain takes 1.5-2 seconds of pure CPU on first invocation + inside any Python process: the graph compiler builds reducers, + Pydantic v2 generates and JITs validator schemas, deepagents + eagerly compiles its general-purpose subagent, etc. Subsequent + compiles in the same process pay only ~50% of that cost (the lazy + JIT bits are cached in module-level dicts). + + Doing one throwaway compile during ``lifespan`` startup pre-pays + that cost so the *first real request* doesn't. We do NOT prime + :mod:`agent_cache` because the cache key requires real + ``thread_id`` / ``user_id`` / ``search_space_id`` / etc. — the + throwaway agent is genuinely thrown away and immediately collected. + + Safety + ------ + * No DB access. We construct a stub LLM (no real keys), pass an + empty tools list, and pass ``checkpointer=None`` so we never + touch Postgres. + * Bounded by ``asyncio.wait_for`` so a hang here can never block + worker startup. On any failure, we log + swallow — the worst + case is the first real request pays the full cold cost (i.e. + pre-warmup behaviour). + """ + import time as _time + + logger = logging.getLogger(__name__) + t0 = _time.perf_counter() + try: + from langchain.agents import create_agent + from langchain.agents.middleware import ( + ModelCallLimitMiddleware, + TodoListMiddleware, + ToolCallLimitMiddleware, + ) + from langchain_core.language_models.fake_chat_models import ( + FakeListChatModel, + ) + from langchain_core.tools import tool + + from app.agents.new_chat.context import SurfSenseContextSchema + + # Minimal LLM stub. ``FakeListChatModel`` satisfies + # ``BaseChatModel`` without any network or auth — perfect for + # exercising the compile path without side effects. + stub_llm = FakeListChatModel(responses=["warmup-response"]) + + # Two trivial tools with arg + return schemas — exercises the + # Pydantic v2 schema JIT path. Without at least one tool the + # graph compile skips the tool-loop bytecode generation that + # accounts for ~30-50% of cold compile cost. + @tool + def _warmup_tool_a(query: str, limit: int = 5) -> str: + """Warmup tool A — never actually invoked.""" + return query[:limit] + + @tool + def _warmup_tool_b(name: str, value: float | None = None) -> dict[str, object]: + """Warmup tool B — never actually invoked.""" + return {"name": name, "value": value} + + # A handful of common middleware so the compile pre-pays the + # ``AgentMiddleware`` resolver path. These instances never run + # because the throwaway agent is immediately collected. + # ``SubAgentMiddleware`` is the single heaviest line in cold + # ``create_surfsense_deep_agent`` (1.5-2s of CPU per call to + # compile its general-purpose subagent's full inner graph), + # so we include it here to make sure that compile path is JIT'd. + warmup_middleware: list = [ + TodoListMiddleware(), + ModelCallLimitMiddleware( + thread_limit=120, run_limit=80, exit_behavior="end" + ), + ToolCallLimitMiddleware( + thread_limit=300, run_limit=80, exit_behavior="continue" + ), + ] + try: + from deepagents import SubAgentMiddleware + from deepagents.backends import StateBackend + from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT + + gp_warmup_spec = { # type: ignore[var-annotated] + **GENERAL_PURPOSE_SUBAGENT, + "model": stub_llm, + "tools": [_warmup_tool_a], + "middleware": [TodoListMiddleware()], + } + warmup_middleware.append( + SubAgentMiddleware(backend=StateBackend, subagents=[gp_warmup_spec]) + ) + except Exception: + # Deepagents missing/incompatible — middleware-only warmup + # still produces a useful (smaller) speedup. + logger.debug("[startup] SubAgentMiddleware warmup skipped", exc_info=True) + + compiled = create_agent( + stub_llm, + tools=[_warmup_tool_a, _warmup_tool_b], + system_prompt="You are a warmup stub.", + middleware=warmup_middleware, + context_schema=SurfSenseContextSchema, + checkpointer=None, + ) + + # Touch the compiled graph's stream_channels / nodes so any + # remaining lazy schema work fires now instead of on first + # real invocation. + _ = list(getattr(compiled, "nodes", {}).keys()) + + del compiled + logger.info( + "[startup] Agent JIT warmup completed in %.3fs", + _time.perf_counter() - t0, + ) + except Exception: + logger.warning( + "[startup] Agent JIT warmup failed in %.3fs (non-fatal — first " + "real request will pay the full compile cost)", + _time.perf_counter() - t0, + exc_info=True, + ) + + @asynccontextmanager async def lifespan(app: FastAPI): # Tune GC: lower gen-2 threshold so long-lived garbage is collected @@ -432,6 +562,7 @@ async def lifespan(app: FastAPI): await setup_checkpointer_tables() initialize_openrouter_integration() _start_openrouter_background_refresh() + initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() initialize_vision_llm_router() @@ -443,6 +574,18 @@ async def lifespan(app: FastAPI): "Docs will be indexed on the next restart." ) + # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays + # worker readiness. ``shield`` so Uvicorn cancelling startup + # doesn't leave half-warmed Pydantic schemas in an inconsistent + # state. + try: + await asyncio.wait_for(asyncio.shield(_warm_agent_jit_caches()), timeout=20) + except (TimeoutError, Exception): # pragma: no cover - defensive + logging.getLogger(__name__).warning( + "[startup] Agent JIT warmup hit timeout/error — skipping; " + "first real request will pay the full compile cost." + ) + log_system_snapshot("startup_complete") yield @@ -452,6 +595,23 @@ async def lifespan(app: FastAPI): def registration_allowed(): + """Master auth kill switch keyed on the REGISTRATION_ENABLED env var. + + Despite the name, this dependency does NOT only gate registration. When + REGISTRATION_ENABLED is FALSE it intentionally blocks every auth surface + that could mint or refresh a session for an attacker: + + * email/password ``POST /auth/register`` + * email/password ``POST /auth/jwt/login`` + * the Google OAuth router (``/auth/google/authorize`` and the shared + ``/auth/google/callback`` handles both new signups and login for + existing users, so flipping this off locks both) + * the bespoke ``/auth/google/authorize-redirect`` helper used by the UI + + Use it as a temporary "freeze all new sessions" lever during incident + response. It is not a way to disable signup while keeping login working; + for that, override ``UserManager.oauth_callback`` instead. + """ if not config.REGISTRATION_ENABLED: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled" @@ -596,32 +756,45 @@ app.add_middleware( allow_headers=["*"], # Allows all headers ) -app.include_router( - fastapi_users.get_auth_router(auth_backend), - prefix="/auth/jwt", - tags=["auth"], - dependencies=[Depends(rate_limit_login)], -) -app.include_router( - fastapi_users.get_register_router(UserRead, UserCreate), - prefix="/auth", - tags=["auth"], - dependencies=[ - Depends(rate_limit_register), - Depends(registration_allowed), # blocks registration when disabled - ], -) -app.include_router( - fastapi_users.get_reset_password_router(), - prefix="/auth", - tags=["auth"], - dependencies=[Depends(rate_limit_password_reset)], -) -app.include_router( - fastapi_users.get_verify_router(UserRead), - prefix="/auth", - tags=["auth"], -) +# Password / email-based auth routers are only mounted when not running in +# Google-OAuth-only mode. Mounting them in OAuth-only prod previously left +# POST /auth/register reachable, which is the bypass that allowed bots to +# create non-OAuth users in spite of AUTH_TYPE=GOOGLE. +if config.AUTH_TYPE != "GOOGLE": + app.include_router( + fastapi_users.get_auth_router(auth_backend), + prefix="/auth/jwt", + tags=["auth"], + dependencies=[ + Depends(rate_limit_login), + Depends( + registration_allowed + ), # honour REGISTRATION_ENABLED kill switch on login too + ], + ) + app.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], + dependencies=[ + Depends(rate_limit_register), + Depends(registration_allowed), + ], + ) + app.include_router( + fastapi_users.get_reset_password_router(), + prefix="/auth", + tags=["auth"], + dependencies=[Depends(rate_limit_password_reset)], + ) + app.include_router( + fastapi_users.get_verify_router(UserRead), + prefix="/auth", + tags=["auth"], + ) + +# /users/me (read/update profile) is needed in every auth mode, so it stays +# mounted unconditionally. app.include_router( fastapi_users.get_users_router(UserRead, UserUpdate), prefix="/users", @@ -679,16 +852,25 @@ if config.AUTH_TYPE == "GOOGLE": ), prefix="/auth/google", tags=["auth"], - dependencies=[ - Depends(registration_allowed) - ], # blocks OAuth registration when disabled + # REGISTRATION_ENABLED is a master auth kill switch: when set to FALSE + # it blocks BOTH new OAuth signups AND login of existing OAuth users + # (the fastapi-users OAuth router shares one callback for create+login, + # so this dependency closes both paths together). + dependencies=[Depends(registration_allowed)], ) # Add a redirect-based authorize endpoint for Firefox/Safari compatibility # This endpoint performs a server-side redirect instead of returning JSON # which fixes cross-site cookie issues where browsers don't send cookies - # set via cross-origin fetch requests on subsequent redirects - @app.get("/auth/google/authorize-redirect", tags=["auth"]) + # set via cross-origin fetch requests on subsequent redirects. + # The registration_allowed dependency mirrors the OAuth router above so + # the kill switch fails fast here instead of bouncing users to Google + # only to 403 on the callback. + @app.get( + "/auth/google/authorize-redirect", + tags=["auth"], + dependencies=[Depends(registration_allowed)], + ) async def google_authorize_redirect( request: Request, ): diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 58a8b0f39..74710d5e1 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -22,10 +22,12 @@ def init_worker(**kwargs): initialize_image_gen_router, initialize_llm_router, initialize_openrouter_integration, + initialize_pricing_registration, initialize_vision_llm_router, ) initialize_openrouter_integration() + initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() initialize_vision_llm_router() diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 675b05d2c..97b4cf509 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -47,11 +47,37 @@ def load_global_llm_configs(): data = yaml.safe_load(f) configs = data.get("global_llm_configs", []) + # Lazy import keeps the `app.config` -> `app.services` edge one-way + # and matches the `provider_api_base` pattern used elsewhere. + from app.services.provider_capabilities import derive_supports_image_input + seen_slugs: dict[str, int] = {} for cfg in configs: cfg.setdefault("billing_tier", "free") cfg.setdefault("anonymous_enabled", False) cfg.setdefault("seo_enabled", False) + # Capability flag: explicit YAML override always wins. When the + # operator has not annotated the model, defer to LiteLLM's + # authoritative model map (`supports_vision`) which already + # knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are + # vision-capable. Unknown / unmapped models default-allow so + # we don't lock the user out of a freshly added third-party + # entry; the streaming-task safety net (driven by + # `is_known_text_only_chat_model`) is the only place a False + # actually blocks a request. + if "supports_image_input" not in cfg: + litellm_params = cfg.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + cfg["supports_image_input"] = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) if cfg.get("seo_enabled") and cfg.get("seo_slug"): slug = cfg["seo_slug"] @@ -138,7 +164,11 @@ def load_global_image_gen_configs(): try: with open(global_config_file, encoding="utf-8") as f: data = yaml.safe_load(f) - return data.get("global_image_generation_configs", []) + configs = data.get("global_image_generation_configs", []) or [] + for cfg in configs: + if isinstance(cfg, dict): + cfg.setdefault("billing_tier", "free") + return configs except Exception as e: print(f"Warning: Failed to load global image generation configs: {e}") return [] @@ -153,7 +183,11 @@ def load_global_vision_llm_configs(): try: with open(global_config_file, encoding="utf-8") as f: data = yaml.safe_load(f) - return data.get("global_vision_llm_configs", []) + configs = data.get("global_vision_llm_configs", []) or [] + for cfg in configs: + if isinstance(cfg, dict): + cfg.setdefault("billing_tier", "free") + return configs except Exception as e: print(f"Warning: Failed to load global vision LLM configs: {e}") return [] @@ -254,6 +288,15 @@ def load_openrouter_integration_settings() -> dict | None: "anonymous_enabled_free", settings["anonymous_enabled"] ) + # Image generation + vision LLM emission are opt-in (issue L). + # OpenRouter's catalogue contains hundreds of image / vision + # capable models; auto-injecting all of them into every + # deployment would explode the model selector and surprise + # operators upgrading from prior versions. Default to False so + # admins must explicitly turn them on. + settings.setdefault("image_generation_enabled", False) + settings.setdefault("vision_enabled", False) + return settings except Exception as e: print(f"Warning: Failed to load OpenRouter integration settings: {e}") @@ -296,10 +339,60 @@ def initialize_openrouter_integration(): ) else: print("Info: OpenRouter integration enabled but no models fetched") + + # Image generation + vision LLM emissions are opt-in (issue L). + # Both reuse the catalogue already cached by ``service.initialize`` + # so we don't make additional network calls here. + if settings.get("image_generation_enabled"): + try: + image_configs = service.get_image_generation_configs() + if image_configs: + config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs) + print( + f"Info: OpenRouter integration added {len(image_configs)} " + f"image-generation models" + ) + except Exception as e: + print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}") + + if settings.get("vision_enabled"): + try: + vision_configs = service.get_vision_llm_configs() + if vision_configs: + config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs) + print( + f"Info: OpenRouter integration added {len(vision_configs)} " + f"vision LLM models" + ) + except Exception as e: + print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}") except Exception as e: print(f"Warning: Failed to initialize OpenRouter integration: {e}") +def initialize_pricing_registration(): + """ + Teach LiteLLM the per-token cost of every deployment in + ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled + from the OpenRouter catalogue + any operator-declared YAML pricing). + + Must run AFTER ``initialize_openrouter_integration()`` so the + OpenRouter catalogue is populated and BEFORE the first LLM call so + ``response_cost`` is available in ``TokenTrackingCallback``. + + Failures are logged but never raised — startup must not be blocked + by a missing pricing entry; the worst-case is the model debits 0. + """ + try: + from app.services.pricing_registration import ( + register_pricing_from_global_configs, + ) + + register_pricing_from_global_configs() + except Exception as e: + print(f"Warning: Failed to register LiteLLM pricing: {e}") + + def initialize_llm_router(): """ Initialize the LLM Router service for Auto mode. @@ -444,14 +537,54 @@ class Config: os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100") ) - # Premium token quota settings - PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000")) + # Premium credit (micro-USD) quota settings. + # + # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy + # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are + # still honoured for one release as fall-back values — the prior + # $1-per-1M-tokens Stripe price means every existing value maps 1:1 + # to micros, so operators upgrading without changing their .env still + # get correct behaviour. A startup deprecation warning fires below if + # they're set. + PREMIUM_CREDIT_MICROS_LIMIT = int( + os.getenv("PREMIUM_CREDIT_MICROS_LIMIT") + or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000") + ) STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID") - STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")) + STRIPE_CREDIT_MICROS_PER_UNIT = int( + os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT") + or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000") + ) STRIPE_TOKEN_BUYING_ENABLED = ( os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE" ) + # Safety ceiling on the per-call premium reservation. ``stream_new_chat`` + # estimates an upper-bound cost from ``litellm.get_model_info`` x the + # config's ``quota_reserve_tokens`` and clamps the result to this value + # so a misconfigured "$1000/M" model can't lock the user's whole balance + # on one call. Default $1.00 covers realistic worst-cases (Opus + 4K + # reserve_tokens ≈ $0.36) with headroom. + QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000")) + + if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv( + "PREMIUM_CREDIT_MICROS_LIMIT" + ): + print( + "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to " + "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the " + "current Stripe price). The old key will be removed in a " + "future release." + ) + if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv( + "STRIPE_CREDIT_MICROS_PER_UNIT" + ): + print( + "Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to " + "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). " + "The old key will be removed in a future release." + ) + # Anonymous / no-login mode settings NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE" ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000")) @@ -464,6 +597,35 @@ class Config: # Default quota reserve tokens when not specified per-model QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000")) + # Per-image reservation (in micro-USD) used by ``billable_call`` for the + # ``POST /image-generations`` endpoint when the global config does not + # override it. $0.05 covers realistic worst-cases for current OpenAI / + # OpenRouter image-gen pricing. Bypassed entirely for free configs. + QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000") + ) + + # Per-podcast reservation (in micro-USD). One agent LLM call generating + # a transcript, typically 5k-20k completion tokens. $0.20 covers a long + # premium-model run. Tune via env. + QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000") + ) + + # Per-video-presentation reservation (in micro-USD). Fan-out of N + # slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``) + # plus refine retries; can produce many premium completions. $1.00 + # covers worst-case. Tune via env. + # + # NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of + # 1_000_000. The override path in ``billable_call`` bypasses the + # per-call clamp in ``estimate_call_reserve_micros``, so this is the + # *actual* hold — raising it via env is fine but means a single video + # task can lock $1+ of credit. + QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000") + ) + # Abuse prevention: concurrent stream cap and CAPTCHA ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2")) ANON_CAPTCHA_REQUEST_THRESHOLD = int( diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 79cbe1e51..d92640c8d 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -19,6 +19,24 @@ # Structure matches NewLLMConfig: # - Model configuration (provider, model_name, api_key, etc.) # - Prompt configuration (system_instructions, citations_enabled) +# +# COST-BASED PREMIUM CREDITS: +# Each premium config bills the user's USD-credit balance based on the +# actual provider cost reported by LiteLLM. For models LiteLLM already +# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything. +# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment) +# or any model LiteLLM doesn't have in its built-in pricing table, declare +# per-token costs inline so they bill correctly: +# +# litellm_params: +# base_model: "my-custom-azure-deploy" +# # USD per token; e.g. 0.000003 == $3.00 per million input tokens +# input_cost_per_token: 0.000003 +# output_cost_per_token: 0.000015 +# +# OpenRouter dynamic models pull pricing automatically from OpenRouter's +# API — no inline declaration needed. Models without resolvable pricing +# debit $0 from the user's balance and log a WARNING. # Router Settings for Auto Mode # These settings control how the LiteLLM Router distributes requests across models @@ -292,6 +310,17 @@ openrouter_integration: free_rpm: 20 free_tpm: 100000 + # Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue + # contains hundreds of image- and vision-capable models; turning these on + # injects them into the global Image-Generation / Vision-LLM model + # selectors alongside any static configs. Tier (free/premium) is derived + # per model the same way it is for chat (`:free` suffix or zero pricing). + # When a user picks a premium image/vision model the call debits the + # shared $5 USD-cost-based premium credit pool — so leaving these off + # avoids surprise quota burn on existing deployments. Default: false. + image_generation_enabled: false + vision_enabled: false + litellm_params: max_tokens: 16384 system_instructions: "" diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 2fe478d9b..aef959ec9 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin): prompt_tokens = Column(Integer, nullable=False, default=0) completion_tokens = Column(Integer, nullable=False, default=0) total_tokens = Column(Integer, nullable=False, default=0) + cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0") model_breakdown = Column(JSONB, nullable=True) call_details = Column(JSONB, nullable=True) @@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin): class PremiumTokenPurchase(Base, TimestampMixin): - """Tracks Stripe checkout sessions used to grant additional premium token credits.""" + """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units). + + Note: the table name is preserved (``premium_token_purchases``) for + operational continuity even though the unit is now USD micro-credits + instead of raw tokens. The ``credit_micros_granted`` column replaced + the legacy ``tokens_granted`` in migration 140; the stored values + were not transformed because the prior $1 = 1M tokens Stripe price + makes the unit conversion 1:1 numerically. + """ __tablename__ = "premium_token_purchases" __allow_unmapped__ = True @@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin): ) stripe_payment_intent_id = Column(String(255), nullable=True, index=True) quantity = Column(Integer, nullable=False) - tokens_granted = Column(BigInteger, nullable=False) + credit_micros_granted = Column(BigInteger, nullable=False) amount_total = Column(Integer, nullable=True) currency = Column(String(10), nullable=True) status = Column( @@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE": ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") - premium_tokens_limit = Column( + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.PREMIUM_TOKEN_LIMIT, - server_default=str(config.PREMIUM_TOKEN_LIMIT), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - premium_tokens_used = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - premium_tokens_reserved = Column( + premium_credit_micros_reserved = Column( BigInteger, nullable=False, default=0, server_default="0" ) @@ -2241,16 +2250,16 @@ else: ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") - premium_tokens_limit = Column( + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.PREMIUM_TOKEN_LIMIT, - server_default=str(config.PREMIUM_TOKEN_LIMIT), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - premium_tokens_used = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - premium_tokens_reserved = Column( + premium_credit_micros_reserved = Column( BigInteger, nullable=False, default=0, server_default="0" ) diff --git a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py index 4bb38b7b0..d45bd780c 100644 --- a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py +++ b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py @@ -68,12 +68,25 @@ class EtlPipelineService: etl_service="VISION_LLM", content_type="image", ) - except Exception: - logging.warning( - "Vision LLM failed for %s, falling back to document parser", - request.filename, - exc_info=True, - ) + except Exception as exc: + # Special-case quota exhaustion so we log a clearer message + # — the vision LLM didn't "fail", the user just ran out of + # premium credit. Falling through to the document parser + # is a graceful degradation: OCR/Unstructured still + # extracts text from the image without burning credit. + from app.services.billable_calls import QuotaInsufficientError + + if isinstance(exc, QuotaInsufficientError): + logging.info( + "Vision LLM quota exhausted for %s; falling back to document parser", + request.filename, + ) + else: + logging.warning( + "Vision LLM failed for %s, falling back to document parser", + request.filename, + exc_info=True, + ) else: logging.info( "No vision LLM provided, falling back to document parser for %s", diff --git a/surfsense_backend/app/routes/agent_flags_route.py b/surfsense_backend/app/routes/agent_flags_route.py index 5732a8dfb..99388af66 100644 --- a/surfsense_backend/app/routes/agent_flags_route.py +++ b/surfsense_backend/app/routes/agent_flags_route.py @@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends from pydantic import BaseModel from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags +from app.config import config from app.db import User from app.users import current_active_user @@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel): enable_otel: bool + enable_desktop_local_filesystem: bool + @classmethod def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead: # asdict() avoids missing-field bugs when AgentFeatureFlags grows. - return cls(**asdict(flags)) + return cls( + **asdict(flags), + enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM, + ) @router.get("/agent/flags", response_model=AgentFeatureFlagsRead) diff --git a/surfsense_backend/app/routes/composio_routes.py b/surfsense_backend/app/routes/composio_routes.py index 4bf360365..7bc2addf8 100644 --- a/surfsense_backend/app/routes/composio_routes.py +++ b/surfsense_backend/app/routes/composio_routes.py @@ -649,13 +649,9 @@ async def list_composio_drive_folders( """ List folders AND files in user's Google Drive via Composio. - Uses the same GoogleDriveClient / list_folder_contents path as the native - connector, with Composio-sourced credentials. This means auth errors - propagate identically (Google returns 401 → exception → auth_expired flag). + Uses Composio's Google Drive tool execution path so managed OAuth tokens + do not need to be exposed through connected account state. """ - from app.connectors.google_drive import GoogleDriveClient, list_folder_contents - from app.utils.google_credentials import build_composio_credentials - if not ComposioService.is_enabled(): raise HTTPException( status_code=503, @@ -689,10 +685,37 @@ async def list_composio_drive_folders( detail="Composio connected account not found. Please reconnect the connector.", ) - credentials = build_composio_credentials(composio_connected_account_id) - drive_client = GoogleDriveClient(session, connector_id, credentials=credentials) + service = ComposioService() + entity_id = f"surfsense_{user.id}" + items = [] + page_token = None + error = None - items, error = await list_folder_contents(drive_client, parent_id=parent_id) + while True: + page_items, next_token, page_error = await service.get_drive_files( + connected_account_id=composio_connected_account_id, + entity_id=entity_id, + folder_id=parent_id, + page_token=page_token, + page_size=100, + ) + if page_error: + error = page_error + break + + items.extend(page_items) + if not next_token: + break + page_token = next_token + + for item in items: + item["isFolder"] = ( + item.get("mimeType") == "application/vnd.google-apps.folder" + ) + + items.sort( + key=lambda item: (not item["isFolder"], item.get("name", "").lower()) + ) if error: error_lower = error.lower() diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 97a3559b9..018234ad5 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -36,11 +36,17 @@ from app.schemas import ( ImageGenerationListRead, ImageGenerationRead, ) +from app.services.billable_calls import ( + DEFAULT_IMAGE_RESERVE_MICROS, + QuotaInsufficientError, + billable_call, +) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, ImageGenRouterService, is_image_gen_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -82,14 +88,62 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: return None +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + """Resolve the LiteLLM provider prefix used in model strings.""" + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + def _build_model_string( provider: str, model_name: str, custom_provider: str | None ) -> str: """Build a litellm model string from provider + model_name.""" - if custom_provider: - return f"{custom_provider}/{model_name}" - prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) - return f"{prefix}/{model_name}" + return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" + + +async def _resolve_billing_for_image_gen( + session: AsyncSession, + config_id: int | None, + search_space: SearchSpace, +) -> tuple[str, str, int]: + """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request. + + The resolution mirrors ``_execute_image_generation``'s lookup tree but + only extracts the fields needed for billing — we do this *before* + ``billable_call`` so the reservation is correctly sized for the + config that will actually run, and so we don't open an + ``ImageGeneration`` row for a request that's about to 402. + + User-owned (positive ID) BYOK configs are always free — they cost + the user nothing on our side. Auto mode currently treats as free + because the underlying router can dispatch to either premium or + free YAML configs and we don't surface the resolved deployment up + here yet. Bringing Auto under premium billing would require + threading the chosen deployment back from ``ImageGenRouterService``. + """ + resolved_id = config_id + if resolved_id is None: + resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + + if is_image_gen_auto_mode(resolved_id): + return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) + + if resolved_id < 0: + cfg = _get_global_image_gen_config(resolved_id) or {} + billing_tier = str(cfg.get("billing_tier", "free")).lower() + base_model = _build_model_string( + cfg.get("provider", ""), + cfg.get("model_name", ""), + cfg.get("custom_provider"), + ) + reserve_micros = int( + cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS + ) + return (billing_tier, base_model, reserve_micros) + + # Positive ID = user-owned BYOK image-gen config — always free. + return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS) async def _execute_image_generation( @@ -138,12 +192,18 @@ async def _execute_image_generation( if not cfg: raise ValueError(f"Global image generation config {config_id} not found") - model_string = _build_model_string( - cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider") + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) + model_string = f"{provider_prefix}/{cfg['model_name']}" gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base if cfg.get("api_version"): gen_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -165,12 +225,18 @@ async def _execute_image_generation( if not db_cfg: raise ValueError(f"Image generation config {config_id} not found") - model_string = _build_model_string( - db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) + model_string = f"{provider_prefix}/{db_cfg.model_name}" gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base if db_cfg.api_version: gen_kwargs["api_version"] = db_cfg.api_version if db_cfg.litellm_params: @@ -225,10 +291,15 @@ async def get_global_image_gen_configs( "litellm_params": {}, "is_global": True, "is_auto_mode": True, + # Auto mode currently treated as free until per-deployment + # billing-tier surfacing lands (see _resolve_billing_for_image_gen). + "billing_tier": "free", + "is_premium": False, } ) for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() safe_configs.append( { "id": cfg.get("id"), @@ -241,6 +312,12 @@ async def get_global_image_gen_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. + "is_premium": billing_tier == "premium", + "quota_reserve_micros": cfg.get("quota_reserve_micros"), } ) @@ -454,7 +531,26 @@ async def create_image_generation( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Create and execute an image generation request.""" + """Create and execute an image generation request. + + Premium configs are gated by the user's shared premium credit pool. + The flow is: + + 1. Permission check + load the search space (cheap, no provider call). + 2. Resolve which config will run so we know its billing tier and the + worst-case reservation size *before* opening any DB rows. + 3. Wrap the entire ImageGeneration row insert + provider call in + ``billable_call``. If quota is denied, ``billable_call`` raises + ``QuotaInsufficientError`` *before* we flush a row, which we + translate to HTTP 402 (no orphaned rows on the user's account, + no inserted error rows for "you ran out of credit"). + 4. On success, the actual ``response_cost`` flows through the + LiteLLM callback into the accumulator, and ``billable_call`` + finalizes the debit at exit. Inner ``try/except`` still catches + provider errors and stores them on ``error_message`` (HTTP 200 + with ``error_message`` set is preserved for failed-but-not-quota + scenarios — clients already know how to surface those). + """ try: await check_permission( session, @@ -471,33 +567,70 @@ async def create_image_generation( if not search_space: raise HTTPException(status_code=404, detail="Search space not found") - db_image_gen = ImageGeneration( - prompt=data.prompt, - model=data.model, - n=data.n, - quality=data.quality, - size=data.size, - style=data.style, - response_format=data.response_format, - image_generation_config_id=data.image_generation_config_id, - search_space_id=data.search_space_id, - created_by_id=user.id, + billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen( + session, data.image_generation_config_id, search_space ) - session.add(db_image_gen) - await session.flush() - try: - await _execute_image_generation(session, db_image_gen, search_space) - except Exception as e: - logger.exception("Image generation call failed") - db_image_gen.error_message = str(e) + # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError + # propagates to the outer ``except QuotaInsufficientError`` handler + # below as HTTP 402 — it is intentionally NOT swallowed into + # ``error_message`` because that would (1) imply a successful row + # exists when none does, and (2) return HTTP 200 to a client + # whose request was actively *denied* (issue K). + async with billable_call( + user_id=search_space.user_id, + search_space_id=data.search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=reserve_micros, + usage_type="image_generation", + call_details={"model": base_model, "prompt": data.prompt[:100]}, + ): + db_image_gen = ImageGeneration( + prompt=data.prompt, + model=data.model, + n=data.n, + quality=data.quality, + size=data.size, + style=data.style, + response_format=data.response_format, + image_generation_config_id=data.image_generation_config_id, + search_space_id=data.search_space_id, + created_by_id=user.id, + ) + session.add(db_image_gen) + await session.flush() - await session.commit() - await session.refresh(db_image_gen) - return db_image_gen + try: + await _execute_image_generation(session, db_image_gen, search_space) + except Exception as e: + logger.exception("Image generation call failed") + db_image_gen.error_message = str(e) + + await session.commit() + await session.refresh(db_image_gen) + return db_image_gen except HTTPException: raise + except QuotaInsufficientError as exc: + # The user's premium credit pool is empty. No DB row is created + # because ``billable_call`` denies before yielding (issue K). + await session.rollback() + raise HTTPException( + status_code=402, + detail={ + "error_code": "premium_quota_exhausted", + "usage_type": exc.usage_type, + "used_micros": exc.used_micros, + "limit_micros": exc.limit_micros, + "remaining_micros": exc.remaining_micros, + "message": ( + "Out of premium credits for image generation. " + "Purchase additional credits or switch to a free model." + ), + }, + ) from exc except SQLAlchemyError: await session.rollback() raise HTTPException( diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 28b197ca2..d3bd51129 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1366,7 +1366,11 @@ async def append_message( # flush assigns the PK/defaults without a round-trip SELECT await session.flush() - # Persist token usage if provided (for assistant messages) + # Persist token usage if provided (for assistant messages). + # ``cost_micros`` is the provider USD cost reported by LiteLLM, + # forwarded by the FE through the appendMessage round-trip so + # the historical TokenUsage row matches the credit debit applied + # at finalize time. token_usage_data = raw_body.get("token_usage") if token_usage_data and message_role == NewChatMessageRole.ASSISTANT: await record_token_usage( @@ -1377,6 +1381,7 @@ async def append_message( prompt_tokens=token_usage_data.get("prompt_tokens", 0), completion_tokens=token_usage_data.get("completion_tokens", 0), total_tokens=token_usage_data.get("total_tokens", 0), + cost_micros=token_usage_data.get("cost_micros", 0), model_breakdown=token_usage_data.get("usage"), call_details=token_usage_data.get("call_details"), thread_id=thread_id, diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py index 20779a309..e090a1a7c 100644 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -29,6 +29,7 @@ from app.schemas import ( NewLLMConfigUpdate, ) from app.services.llm_service import validate_llm_config +from app.services.provider_capabilities import derive_supports_image_input from app.users import current_active_user from app.utils.rbac import check_permission @@ -36,6 +37,39 @@ router = APIRouter() logger = logging.getLogger(__name__) +def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: + """Augment a BYOK chat config row with the derived ``supports_image_input``. + + There is no DB column for ``supports_image_input`` — the value is + resolved at the API boundary from LiteLLM's authoritative model map + (default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps + the response shape consistent across list / detail / create / update + endpoints without having to remember to set the field at every call + site. + """ + provider_value = ( + config.provider.value + if hasattr(config.provider, "value") + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") if isinstance(litellm_params, dict) else None + ) + supports_image_input = derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ) + # ``model_validate`` runs the Pydantic conversion using the ORM + # attribute access path enabled by ``ConfigDict(from_attributes=True)``, + # then we layer the derived field on. ``model_copy(update=...)`` keeps + # the surface immutable from the caller's perspective. + base_read = NewLLMConfigRead.model_validate(config) + return base_read.model_copy(update={"supports_image_input": supports_image_input}) + + # ============================================================================= # Global Configs Routes # ============================================================================= @@ -84,11 +118,41 @@ async def get_global_new_llm_configs( "seo_title": None, "seo_description": None, "quota_reserve_tokens": None, + # Auto routes across the configured pool, which usually + # includes at least one vision-capable deployment, so + # treat Auto as image-capable. The router itself will + # still pick a vision-capable deployment for messages + # carrying image_url blocks (LiteLLM Router falls back + # on ``404`` per its ``allowed_fails`` policy). + "supports_image_input": True, } ) # Add individual global configs for cfg in global_configs: + # Capability resolution: explicit value (YAML override or OR + # `_supports_image_input(model)` payload baked in by the + # OpenRouter integration service) wins. Fall back to the + # LiteLLM-driven helper which default-allows on unknown so + # we don't hide vision-capable models that happen to lack a + # YAML annotation. The streaming task safety net is the + # only place a False ever blocks. + if "supports_image_input" in cfg: + supports_image_input = bool(cfg.get("supports_image_input")) + else: + cfg_litellm_params = cfg.get("litellm_params") or {} + cfg_base_model = ( + cfg_litellm_params.get("base_model") + if isinstance(cfg_litellm_params, dict) + else None + ) + supports_image_input = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=cfg_base_model, + custom_provider=cfg.get("custom_provider"), + ) + safe_config = { "id": cfg.get("id"), "name": cfg.get("name"), @@ -113,6 +177,7 @@ async def get_global_new_llm_configs( "seo_title": cfg.get("seo_title"), "seo_description": cfg.get("seo_description"), "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "supports_image_input": supports_image_input, } safe_configs.append(safe_config) @@ -171,7 +236,7 @@ async def create_new_llm_config( await session.commit() await session.refresh(db_config) - return db_config + return _serialize_byok_config(db_config) except HTTPException: raise @@ -213,7 +278,7 @@ async def list_new_llm_configs( .limit(limit) ) - return result.scalars().all() + return [_serialize_byok_config(cfg) for cfg in result.scalars().all()] except HTTPException: raise @@ -268,7 +333,7 @@ async def get_new_llm_config( "You don't have permission to view LLM configurations in this search space", ) - return config + return _serialize_byok_config(config) except HTTPException: raise @@ -360,7 +425,7 @@ async def update_new_llm_config( await session.commit() await session.refresh(config) - return config + return _serialize_byok_config(config) except HTTPException: raise diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index e44455ad3..0f0e43035 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -591,6 +591,7 @@ async def _get_image_gen_config_by_id( "model_name": "auto", "is_global": True, "is_auto_mode": True, + "billing_tier": "free", } if config_id < 0: @@ -607,6 +608,7 @@ async def _get_image_gen_config_by_id( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), } return None @@ -649,6 +651,7 @@ async def _get_vision_llm_config_by_id( "model_name": "auto", "is_global": True, "is_auto_mode": True, + "billing_tier": "free", } if config_id < 0: @@ -665,6 +668,7 @@ async def _get_vision_llm_config_by_id( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), } return None diff --git a/surfsense_backend/app/routes/stripe_routes.py b/surfsense_backend/app/routes/stripe_routes.py index cfdd4b52a..aed74ec8d 100644 --- a/surfsense_backend/app/routes/stripe_routes.py +++ b/surfsense_backend/app/routes/stripe_routes.py @@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase( metadata = _get_metadata(checkout_session) user_id = metadata.get("user_id") quantity = int(metadata.get("quantity", "0")) - tokens_per_unit = int(metadata.get("tokens_per_unit", "0")) + # Read the new metadata key first, fall back to the legacy one so + # in-flight checkout sessions created before the cost-credits + # release still fulfil correctly (the unit is numerically the + # same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens). + credit_micros_per_unit = int( + metadata.get("credit_micros_per_unit") + or metadata.get("tokens_per_unit", "0") + ) - if not user_id or quantity <= 0 or tokens_per_unit <= 0: + if not user_id or quantity <= 0 or credit_micros_per_unit <= 0: logger.error( "Skipping token fulfillment for session %s: incomplete metadata %s", checkout_session_id, @@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase( getattr(checkout_session, "payment_intent", None) ), quantity=quantity, - tokens_granted=quantity * tokens_per_unit, + credit_micros_granted=quantity * credit_micros_per_unit, amount_total=getattr(checkout_session, "amount_total", None), currency=getattr(checkout_session, "currency", None), status=PremiumTokenPurchaseStatus.PENDING, @@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase( purchase.stripe_payment_intent_id = _normalize_optional_string( getattr(checkout_session, "payment_intent", None) ) - user.premium_tokens_limit = ( - max(user.premium_tokens_used, user.premium_tokens_limit) - + purchase.tokens_granted + # Top up the user's credit balance by the granted micro-USD amount. + # ``max(used, limit)`` clamps the case where the legacy code wrote a + # used value above the limit (e.g. underbilling rounding) so adding + # ``credit_micros_granted`` always lifts the limit by the full pack + # size rather than disappearing into past overuse. + user.premium_credit_micros_limit = ( + max(user.premium_credit_micros_used, user.premium_credit_micros_limit) + + purchase.credit_micros_granted ) await db_session.commit() @@ -532,12 +544,18 @@ async def create_token_checkout_session( user: User = Depends(current_active_user), db_session: AsyncSession = Depends(get_async_session), ): - """Create a Stripe Checkout Session for buying premium token packs.""" + """Create a Stripe Checkout Session for buying premium credit packs. + + Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of + credit (default 1_000_000 = $1.00). The user's balance is debited + at the actual provider cost reported by LiteLLM at finalize time, + so $1 of credit always buys $1 worth of provider usage at cost. + """ _ensure_token_buying_enabled() stripe_client = get_stripe_client() price_id = _get_required_token_price_id() success_url, cancel_url = _get_token_checkout_urls(body.search_space_id) - tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT + credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT try: checkout_session = stripe_client.v1.checkout.sessions.create( @@ -556,8 +574,8 @@ async def create_token_checkout_session( "metadata": { "user_id": str(user.id), "quantity": str(body.quantity), - "tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT), - "purchase_type": "premium_tokens", + "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT), + "purchase_type": "premium_credit", }, } ) @@ -583,7 +601,7 @@ async def create_token_checkout_session( getattr(checkout_session, "payment_intent", None) ), quantity=body.quantity, - tokens_granted=tokens_granted, + credit_micros_granted=credit_micros_granted, amount_total=getattr(checkout_session, "amount_total", None), currency=getattr(checkout_session, "currency", None), status=PremiumTokenPurchaseStatus.PENDING, @@ -598,14 +616,19 @@ async def create_token_checkout_session( async def get_token_status( user: User = Depends(current_active_user), ): - """Return token-buying availability and current premium quota for frontend.""" - used = user.premium_tokens_used - limit = user.premium_tokens_limit + """Return token-buying availability and current premium credit quota for frontend. + + Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M + when displaying. The route name is preserved for back-compat with + pinned client deployments. + """ + used = user.premium_credit_micros_used + limit = user.premium_credit_micros_limit return TokenStripeStatusResponse( token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED, - premium_tokens_used=used, - premium_tokens_limit=limit, - premium_tokens_remaining=max(0, limit - used), + premium_credit_micros_used=used, + premium_credit_micros_limit=limit, + premium_credit_micros_remaining=max(0, limit - used), ) diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index 315c7c9fe..e4f08f604 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -82,10 +82,15 @@ async def get_global_vision_llm_configs( "litellm_params": {}, "is_global": True, "is_auto_mode": True, + # Auto mode treated as free until per-deployment billing-tier + # surfacing lands; see ``get_vision_llm`` for parity. + "billing_tier": "free", + "is_premium": False, } ) for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() safe_configs.append( { "id": cfg.get("id"), @@ -98,6 +103,14 @@ async def get_global_vision_llm_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. + "is_premium": billing_tier == "premium", + "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "input_cost_per_token": cfg.get("input_cost_per_token"), + "output_cost_per_token": cfg.get("output_cost_per_token"), } ) diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index 69f534e20..4262b2b3f 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel): Schema for reading global image generation configs from YAML. Global configs have negative IDs. API key is hidden. ID 0 is reserved for Auto mode (LiteLLM Router load balancing). + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. ``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. """ id: int = Field( @@ -231,3 +237,24 @@ class GlobalImageGenConfigRead(BaseModel): litellm_params: dict[str, Any] | None = None is_global: bool = True is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", + ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." + ), + ) + quota_reserve_micros: int | None = Field( + default=None, + description=( + "Optional override for the reservation amount (in micro-USD) used when " + "this image generation is premium. Falls back to " + "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted." + ), + ) diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index ec5eefc07..892ff9693 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel): prompt_tokens: int = 0 completion_tokens: int = 0 total_tokens: int = 0 + cost_micros: int = 0 model_breakdown: dict | None = None model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py index 9cc1fce58..e64478d38 100644 --- a/surfsense_backend/app/schemas/new_llm_config.py +++ b/surfsense_backend/app/schemas/new_llm_config.py @@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase): created_at: datetime search_space_id: int user_id: uuid.UUID + # Capability flag derived at the API boundary (no DB column). Default + # True matches the conservative-allow stance — a BYOK row that the + # route forgot to augment is not pre-judged. The streaming-task + # safety net is the only place a False actually blocks a request. + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map " + "(``litellm.supports_vision``) — there is no DB column. " + "Default True is the conservative-allow stance for unknown / " + "unmapped models." + ), + ) model_config = ConfigDict(from_attributes=True) @@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel): created_at: datetime search_space_id: int user_id: uuid.UUID + # Capability flag derived at the API boundary (see NewLLMConfigRead). + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map. " + "Default True is the conservative-allow stance." + ), + ) model_config = ConfigDict(from_attributes=True) @@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel): seo_title: str | None = None seo_description: str | None = None quota_reserve_tokens: int | None = None + supports_image_input: bool = Field( + default=True, + description=( + "Whether the model accepts image inputs (multimodal vision). " + "Derived server-side: OpenRouter dynamic configs use " + "``architecture.input_modalities``; YAML / BYOK use LiteLLM's " + "authoritative model map (``litellm.supports_vision``). The " + "new-chat selector hints with a 'No image' badge when this is " + "False and there are pending image attachments. The streaming " + "task fails fast only when LiteLLM *explicitly* marks a model " + "as text-only — unknown / unmapped models default-allow." + ), + ) # ============================================================================= diff --git a/surfsense_backend/app/schemas/stripe.py b/surfsense_backend/app/schemas/stripe.py index 3edd3e9e4..57265ec8e 100644 --- a/surfsense_backend/app/schemas/stripe.py +++ b/surfsense_backend/app/schemas/stripe.py @@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel): class TokenPurchaseRead(BaseModel): - """Serialized premium token purchase record.""" + """Serialized premium credit purchase record. + + ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The + schema name kept ``Token`` for API back-compat with pinned clients. + """ id: uuid.UUID stripe_checkout_session_id: str stripe_payment_intent_id: str | None = None quantity: int - tokens_granted: int + credit_micros_granted: int amount_total: int | None = None currency: str | None = None status: str @@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel): class TokenPurchaseHistoryResponse(BaseModel): - """Response containing the user's premium token purchases.""" + """Response containing the user's premium credit purchases.""" purchases: list[TokenPurchaseRead] class TokenStripeStatusResponse(BaseModel): - """Response describing token-buying availability and current quota.""" + """Response describing premium-credit-buying availability and balance. + + All ``premium_credit_micros_*`` fields are in micro-USD; the FE + divides by 1_000_000 to display USD. + """ token_buying_enabled: bool - premium_tokens_used: int = 0 - premium_tokens_limit: int = 0 - premium_tokens_remaining: int = 0 + premium_credit_micros_used: int = 0 + premium_credit_micros_limit: int = 0 + premium_credit_micros_remaining: int = 0 diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py index ab2e609dc..d0eeaf5c6 100644 --- a/surfsense_backend/app/schemas/vision_llm.py +++ b/surfsense_backend/app/schemas/vision_llm.py @@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel): class GlobalVisionLLMConfigRead(BaseModel): + """Schema for reading global vision LLM configs from YAML. + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. ``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. + """ + id: int = Field(...) name: str description: str | None = None @@ -73,3 +82,35 @@ class GlobalVisionLLMConfigRead(BaseModel): litellm_params: dict[str, Any] | None = None is_global: bool = True is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", + ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." + ), + ) + quota_reserve_tokens: int | None = Field( + default=None, + description=( + "Optional override for the per-call reservation in *tokens* — " + "converted to micro-USD via the model's input/output prices at " + "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS." + ), + ) + input_cost_per_token: float | None = Field( + default=None, + description=( + "Optional input price in USD/token. Used by pricing_registration to " + "register custom Azure / OpenRouter aliases with LiteLLM at startup." + ), + ) + output_cost_per_token: float | None = Field( + default=None, + description="Optional output price in USD/token. Pair with input_cost_per_token.", + ) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 3a2c681b7..9bbca8669 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -163,13 +163,47 @@ def clear_healthy(config_id: int | None = None) -> None: _healthy_until.pop(int(config_id), None) -def _global_candidates() -> list[dict]: +def _cfg_supports_image_input(cfg: dict) -> bool: + """True if the global cfg can accept image inputs. + + Prefers the explicit ``supports_image_input`` flag (set by the YAML + loader / OpenRouter integration). Falls back to a LiteLLM lookup so + a YAML entry whose flag was somehow stripped doesn't get wrongly + excluded. Default-allows on unknown — the streaming-task safety net + is the actual block, not this filter. + """ + if "supports_image_input" in cfg: + return bool(cfg.get("supports_image_input")) + # Lazy import: provider_capabilities -> llm_config -> services chain; + # importing at module load would create an init-order cycle through + # ``app.config``. + from app.services.provider_capabilities import derive_supports_image_input + + cfg_litellm_params = cfg.get("litellm_params") or {} + base_model = ( + cfg_litellm_params.get("base_model") + if isinstance(cfg_litellm_params, dict) + else None + ) + return derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + + +def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: """Return Auto-eligible global cfgs. Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers can't be picked as the thread's pin. Also excludes configs currently in runtime cooldown (e.g. temporary 429 bursts). + + When ``requires_image_input`` is True (image turn), additionally + filters out configs whose ``supports_image_input`` resolves to False + so a text-only deployment can't be pinned for an image request. """ candidates = [ cfg @@ -177,6 +211,7 @@ def _global_candidates() -> list[dict]: if _is_usable_global_config(cfg) and not cfg.get("health_gated") and not _is_runtime_cooled_down(int(cfg.get("id", 0))) + and (not requires_image_input or _cfg_supports_image_input(cfg)) ] return sorted(candidates, key=lambda c: int(c.get("id", 0))) @@ -185,6 +220,15 @@ def _tier_of(cfg: dict) -> str: return str(cfg.get("billing_tier", "free")).lower() +def _is_preferred_premium_auto_config(cfg: dict) -> bool: + """Return True for the operator-preferred premium Auto model.""" + return ( + _tier_of(cfg) == "premium" + and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI" + and str(cfg.get("model_name", "")).lower() == "gpt-5.4" + ) + + def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: """Pick a config with quality-first ranking + deterministic spread. @@ -237,11 +281,20 @@ async def resolve_or_get_pinned_llm_config_id( selected_llm_config_id: int, force_repin_free: bool = False, exclude_config_ids: set[int] | None = None, + requires_image_input: bool = False, ) -> AutoPinResolution: """Resolve Auto (Fastest) to one concrete config id and persist the pin. For non-auto selections, this function clears any existing pin and returns the selected id as-is. + + When ``requires_image_input`` is True (the current turn carries an + ``image_url`` block), the candidate pool is filtered to vision-capable + cfgs and any existing pin that can't accept image input is treated as + invalid (force re-pin). If no vision-capable cfg is available the + function raises ``ValueError`` so the streaming task surfaces the same + friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of + silently routing the image to a text-only deployment. """ thread = ( ( @@ -274,14 +327,24 @@ async def resolve_or_get_pinned_llm_config_id( excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} candidates = [ - c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids + c + for c in _global_candidates(requires_image_input=requires_image_input) + if int(c.get("id", 0)) not in excluded_ids ] if not candidates: + if requires_image_input: + # Distinguish the "no vision-capable cfg" case from generic + # "no usable cfg" so the streaming task can map this to the + # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error. + raise ValueError( + "No vision-capable global LLM configs are available for Auto mode" + ) raise ValueError("No usable global LLM configs are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent - # tier switch), unless the caller explicitly requests a forced repin to free. + # tier switch), unless the caller explicitly requests a forced repin to free + # *or* the turn requires image input but the pin can't handle it. pinned_id = thread.pinned_llm_config_id if ( not force_repin_free @@ -311,6 +374,29 @@ async def resolve_or_get_pinned_llm_config_id( from_existing_pin=True, ) if pinned_id is not None: + # If the pin is *only* invalid because it can't handle the image + # turn (it's still a healthy, usable config in the broader pool), + # log that explicitly so operators can correlate the re-pin with + # the user's image attachment instead of suspecting a cooldown. + if requires_image_input: + try: + pinned_global = next( + c + for c in config.GLOBAL_LLM_CONFIGS + if int(c.get("id", 0)) == int(pinned_id) + ) + except StopIteration: + pinned_global = None + if pinned_global is not None and not _cfg_supports_image_input( + pinned_global + ): + logger.info( + "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " + "previous_config_id=%s", + thread_id, + search_space_id, + pinned_id, + ) logger.info( "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", thread_id, @@ -322,11 +408,19 @@ async def resolve_or_get_pinned_llm_config_id( False if force_repin_free else await _is_premium_eligible(session, user_id) ) if premium_eligible: - eligible = candidates + premium_candidates = [c for c in candidates if _tier_of(c) == "premium"] + preferred_premium = [ + c for c in premium_candidates if _is_preferred_premium_auto_config(c) + ] + eligible = preferred_premium or premium_candidates else: eligible = [c for c in candidates if _tier_of(c) != "premium"] if not eligible: + if requires_image_input: + raise ValueError( + "Auto mode could not find a vision-capable LLM config for this user and quota state" + ) raise ValueError( "Auto mode could not find an eligible LLM config for this user and quota state" ) diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py new file mode 100644 index 000000000..92ccd6a78 --- /dev/null +++ b/surfsense_backend/app/services/billable_calls.py @@ -0,0 +1,566 @@ +""" +Per-call billable wrapper for image generation, vision LLM extraction, and +any other short-lived premium operation that must charge against the user's +shared premium credit pool. + +The ``billable_call`` async context manager encapsulates the standard +"reserve → execute → finalize / release → record audit row" lifecycle in a +single primitive so callers (the image-generation REST route and the +vision-LLM wrapper used during indexing) don't have to re-implement it. + +KEY DESIGN POINTS (issue A, B): + +1. **Session isolation.** ``billable_call`` takes no caller transaction. + All ``TokenQuotaService.premium_*`` calls and the audit-row insert run + inside their own session context. Route callers use + ``shielded_async_session()`` by default; Celery callers can provide a + worker-loop-safe session factory. This guarantees that quota + commit/rollback can never accidentally flush or roll back rows the caller + has staged in its main session (e.g. a freshly-created + ``ImageGeneration`` row). + +2. **ContextVar safety.** The accumulator is scoped via + :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a + nested ``billable_call`` inside an outer chat turn cannot corrupt the + chat turn's accumulator. + +3. **Free configs are still audited.** Free calls bypass the reserve / + finalize dance entirely but still record a ``TokenUsage`` audit row with + the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution + pipeline complete for analytics even when nothing is debited. + +4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is + responsible for translating that into HTTP 402. We *do not* catch the + denial inside ``billable_call`` — letting it propagate also prevents + the image-generation route from creating an ``ImageGeneration`` row + for a request that never actually ran. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import AsyncIterator, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress +from typing import Any +from uuid import UUID, uuid4 + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import shielded_async_session +from app.services.token_quota_service import ( + TokenQuotaService, + estimate_call_reserve_micros, +) +from app.services.token_tracking_service import ( + TurnTokenAccumulator, + record_token_usage, + scoped_turn, +) + +logger = logging.getLogger(__name__) + +AUDIT_TIMEOUT_SECONDS = 10.0 +BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset( + {"video_presentation_generation", "podcast_generation"} +) +BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]] + + +class QuotaInsufficientError(Exception): + """Raised when ``TokenQuotaService.premium_reserve`` denies a billable + call because the user has exhausted their premium credit pool. + + The route handler should catch this and return HTTP 402 Payment + Required (or the equivalent for the surface area). Outside of the HTTP + layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing) + callers may catch this and degrade gracefully — e.g. fall back to OCR + when vision is unavailable. + """ + + def __init__( + self, + *, + usage_type: str, + used_micros: int, + limit_micros: int, + remaining_micros: int, + ) -> None: + self.usage_type = usage_type + self.used_micros = used_micros + self.limit_micros = limit_micros + self.remaining_micros = remaining_micros + super().__init__( + f"Premium credit exhausted for {usage_type}: " + f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)" + ) + + +class BillingSettlementError(Exception): + """Raised when a premium call completed but credit settlement failed.""" + + def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None: + self.usage_type = usage_type + self.user_id = user_id + super().__init__( + f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}" + ) + + +async def _rollback_safely(session: AsyncSession) -> None: + rollback = getattr(session, "rollback", None) + if rollback is not None: + with suppress(Exception): + await rollback() + + +async def _record_audit_best_effort( + *, + session_factory: BillableSessionFactory, + usage_type: str, + search_space_id: int, + user_id: UUID, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + cost_micros: int, + model_breakdown: dict[str, Any], + call_details: dict[str, Any] | None, + thread_id: int | None, + message_id: int | None, + audit_label: str, + timeout_seconds: float = AUDIT_TIMEOUT_SECONDS, +) -> None: + """Persist a TokenUsage row without letting audit failure block callers. + + Premium settlement is mandatory, but TokenUsage is an audit trail. If the + audit insert or commit hangs, user-facing artifacts such as videos and + podcasts must still be able to transition to READY after settlement. + """ + audit_thread_id = ( + None if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES else thread_id + ) + + async def _persist() -> None: + logger.info( + "[billable_call] audit start label=%s usage_type=%s user=%s thread=%s " + "total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + total_tokens, + cost_micros, + ) + async with session_factory() as audit_session: + try: + await record_token_usage( + audit_session, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost_micros=cost_micros, + model_breakdown=model_breakdown, + call_details=call_details, + thread_id=audit_thread_id, + message_id=message_id, + ) + logger.info( + "[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s", + audit_label, + usage_type, + user_id, + audit_thread_id, + ) + await audit_session.commit() + logger.info( + "[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s", + audit_label, + usage_type, + user_id, + audit_thread_id, + ) + except BaseException: + await _rollback_safely(audit_session) + raise + + try: + await asyncio.wait_for(_persist(), timeout=timeout_seconds) + except TimeoutError: + logger.warning( + "[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s " + "timeout=%.1fs total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + timeout_seconds, + total_tokens, + cost_micros, + ) + except Exception: + logger.exception( + "[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s " + "total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + total_tokens, + cost_micros, + ) + + +@asynccontextmanager +async def billable_call( + *, + user_id: UUID, + search_space_id: int, + billing_tier: str, + base_model: str, + quota_reserve_tokens: int | None = None, + quota_reserve_micros_override: int | None = None, + usage_type: str, + thread_id: int | None = None, + message_id: int | None = None, + call_details: dict[str, Any] | None = None, + billable_session_factory: BillableSessionFactory | None = None, + audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS, +) -> AsyncIterator[TurnTokenAccumulator]: + """Wrap a single billable LLM/image call. + + Args: + user_id: Owner of the credit pool to debit. For vision-LLM during + indexing this is the *search-space owner* (issue M), not the + triggering user. + search_space_id: Required — recorded on the ``TokenUsage`` audit row. + billing_tier: ``"premium"`` debits; anything else (``"free"``) skips + the reserve/finalize dance but still records an audit row with + the captured cost. + base_model: Used by :func:`estimate_call_reserve_micros` to compute + a worst-case reservation from LiteLLM's pricing table. + quota_reserve_tokens: Optional per-config override for the chat-style + reserve estimator (vision LLM uses this). + quota_reserve_micros_override: Optional flat micro-USD reservation + (image generation uses this — its cost shape is per-image, not + per-token). + usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc. + Recorded on the ``TokenUsage`` row. + thread_id, message_id: Optional FK columns on ``TokenUsage``. + call_details: Optional per-call metadata (model name, parameters) + forwarded to ``record_token_usage``. + billable_session_factory: Optional async context factory used for + reserve/finalize/release/audit sessions. Defaults to + ``shielded_async_session`` for route callers; Celery callers pass + a worker-loop-safe session factory. + audit_timeout_seconds: Upper bound for TokenUsage audit persistence. + Audit failure is best-effort and does not undo successful + settlement. + + Yields: + The ``TurnTokenAccumulator`` scoped to this call. The caller invokes + the underlying LLM/image API while inside the ``async with``; the + ``TokenTrackingCallback`` populates the accumulator automatically. + + Raises: + QuotaInsufficientError: when premium and ``premium_reserve`` denies. + """ + is_premium = billing_tier == "premium" + session_factory = billable_session_factory or shielded_async_session + + async with scoped_turn() as acc: + # ---------- Free path: just audit ------------------------------- + if not is_premium: + try: + yield acc + finally: + # Always audit, even on exception, so we capture cost when + # provider returns successfully but the caller raises later. + await _record_audit_best_effort( + session_factory=session_factory, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=acc.total_cost_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + audit_label="free", + timeout_seconds=audit_timeout_seconds, + ) + return + + # ---------- Premium path: reserve → execute → finalize ---------- + if quota_reserve_micros_override is not None: + reserve_micros = max(1, int(quota_reserve_micros_override)) + else: + reserve_micros = estimate_call_reserve_micros( + base_model=base_model or "", + quota_reserve_tokens=quota_reserve_tokens, + ) + + request_id = str(uuid4()) + + async with session_factory() as quota_session: + reserve_result = await TokenQuotaService.premium_reserve( + db_session=quota_session, + user_id=user_id, + request_id=request_id, + reserve_micros=reserve_micros, + ) + + if not reserve_result.allowed: + logger.info( + "[billable_call] reserve DENIED user=%s usage_type=%s " + "reserve=%d used=%d limit=%d remaining=%d", + user_id, + usage_type, + reserve_micros, + reserve_result.used, + reserve_result.limit, + reserve_result.remaining, + ) + raise QuotaInsufficientError( + usage_type=usage_type, + used_micros=reserve_result.used, + limit_micros=reserve_result.limit, + remaining_micros=reserve_result.remaining, + ) + + logger.info( + "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d " + "(remaining=%d)", + user_id, + usage_type, + reserve_micros, + reserve_result.remaining, + ) + + try: + yield acc + except BaseException: + # Release on any failure (including QuotaInsufficientError raised + # from a downstream call, asyncio cancellation, etc.). We use + # BaseException so cancellation also releases. + try: + async with session_factory() as quota_session: + await TokenQuotaService.premium_release( + db_session=quota_session, + user_id=user_id, + reserved_micros=reserve_micros, + ) + except Exception: + logger.exception( + "[billable_call] premium_release failed for user=%s " + "reserve_micros=%d (reservation will be GC'd by quota " + "reconciliation if/when implemented)", + user_id, + reserve_micros, + ) + raise + + # ---------- Success: finalize + audit ---------------------------- + actual_micros = acc.total_cost_micros + try: + logger.info( + "[billable_call] finalize start user=%s usage_type=%s actual=%d " + "reserved=%d thread=%s", + user_id, + usage_type, + actual_micros, + reserve_micros, + thread_id, + ) + async with session_factory() as quota_session: + final_result = await TokenQuotaService.premium_finalize( + db_session=quota_session, + user_id=user_id, + request_id=request_id, + actual_micros=actual_micros, + reserved_micros=reserve_micros, + ) + logger.info( + "[billable_call] finalize user=%s usage_type=%s actual=%d " + "reserved=%d → used=%d/%d (remaining=%d)", + user_id, + usage_type, + actual_micros, + reserve_micros, + final_result.used, + final_result.limit, + final_result.remaining, + ) + except Exception as finalize_exc: + # Last-ditch: if finalize itself fails, we must at least release + # so the reservation doesn't leak. + logger.exception( + "[billable_call] premium_finalize failed for user=%s; " + "attempting release", + user_id, + ) + try: + async with session_factory() as quota_session: + await TokenQuotaService.premium_release( + db_session=quota_session, + user_id=user_id, + reserved_micros=reserve_micros, + ) + except Exception: + logger.exception( + "[billable_call] release after finalize failure ALSO failed " + "for user=%s", + user_id, + ) + raise BillingSettlementError( + usage_type=usage_type, + user_id=user_id, + cause=finalize_exc, + ) from finalize_exc + + await _record_audit_best_effort( + session_factory=session_factory, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=actual_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + audit_label="premium", + timeout_seconds=audit_timeout_seconds, + ) + + +async def _resolve_agent_billing_for_search_space( + session: AsyncSession, + search_space_id: int, + *, + thread_id: int | None = None, +) -> tuple[UUID, str, str]: + """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space + agent LLM. + + Used by Celery tasks (podcast generation, video presentation) to bill the + search-space owner's premium credit pool when the agent LLM is premium. + + Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: + + - Search space not found / no ``agent_llm_id``: raise ``ValueError``. + - **Auto mode** (``id == AUTO_FASTEST_ID == 0``): + * ``thread_id`` is set: delegate to + ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and + recurse into the resolved id. Reuses chat's existing pin if present + so the same model bills for chat + downstream podcast/video. If the + user is not premium-eligible, the pin service auto-restricts to free + deployments — denial only happens later in + ``billable_call.premium_reserve`` if the pin really is premium and + credit ran out mid-flow. + * ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat + for any future direct-API path; today both Celery tasks always pass + ``thread_id``. + - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]`` + (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault), + ``base_model = litellm_params.get("base_model") or model_name`` — + NOT provider-prefixed, matching chat's cost-map lookup convention. + - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches + ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``); + ``base_model`` from ``litellm_params`` or ``model_name``. + + Note on imports: ``llm_service``, ``auto_model_pin_service``, and + ``llm_router_service`` are imported lazily inside the function body to + avoid hoisting litellm side-effects (``litellm.callbacks = + [token_tracker]``, ``litellm.drop_params``, etc.) into + ``billable_calls.py``'s module load path. + """ + from sqlalchemy import select + + from app.db import NewLLMConfig, SearchSpace + + result = await session.execute( + select(SearchSpace).where(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if search_space is None: + raise ValueError(f"Search space {search_space_id} not found") + + agent_llm_id = search_space.agent_llm_id + if agent_llm_id is None: + raise ValueError( + f"Search space {search_space_id} has no agent_llm_id configured" + ) + + owner_user_id: UUID = search_space.user_id + + from app.services.auto_model_pin_service import ( + AUTO_FASTEST_ID, + resolve_or_get_pinned_llm_config_id, + ) + + if agent_llm_id == AUTO_FASTEST_ID: + if thread_id is None: + return owner_user_id, "free", "auto" + try: + resolution = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=str(owner_user_id), + selected_llm_config_id=AUTO_FASTEST_ID, + ) + except ValueError: + logger.warning( + "[agent_billing] Auto-mode pin resolution failed for " + "search_space=%s thread=%s; falling back to free", + search_space_id, + thread_id, + exc_info=True, + ) + return owner_user_id, "free", "auto" + agent_llm_id = resolution.resolved_llm_config_id + + if agent_llm_id < 0: + from app.services.llm_service import get_global_llm_config + + cfg = get_global_llm_config(agent_llm_id) or {} + billing_tier = str(cfg.get("billing_tier", "free")).lower() + litellm_params = cfg.get("litellm_params") or {} + base_model = litellm_params.get("base_model") or cfg.get("model_name") or "" + return owner_user_id, billing_tier, base_model + + nlc_result = await session.execute( + select(NewLLMConfig).where( + NewLLMConfig.id == agent_llm_id, + NewLLMConfig.search_space_id == search_space_id, + ) + ) + nlc = nlc_result.scalars().first() + base_model = "" + if nlc is not None: + litellm_params = nlc.litellm_params or {} + base_model = litellm_params.get("base_model") or nlc.model_name or "" + return owner_user_id, "free", base_model + + +__all__ = [ + "BillingSettlementError", + "QuotaInsufficientError", + "_resolve_agent_billing_for_search_space", + "billable_call", +] + + +# Re-export the config knob so callers don't have to import config just for +# the default image reserve. +DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS diff --git a/surfsense_backend/app/services/composio_service.py b/surfsense_backend/app/services/composio_service.py index a8abe4aa8..edfab1d15 100644 --- a/surfsense_backend/app/services/composio_service.py +++ b/surfsense_backend/app/services/composio_service.py @@ -408,12 +408,37 @@ class ComposioService: files = [] next_token = None if isinstance(data, dict): + inner_data = data.get("data", data) + response_data = ( + inner_data.get("response_data", {}) + if isinstance(inner_data, dict) + else {} + ) # Try direct access first, then nested - files = data.get("files", []) or data.get("data", {}).get("files", []) + files = ( + data.get("files", []) + or ( + inner_data.get("files", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("files", []) + ) next_token = ( data.get("nextPageToken") or data.get("next_page_token") - or data.get("data", {}).get("nextPageToken") + or ( + inner_data.get("nextPageToken") + if isinstance(inner_data, dict) + else None + ) + or ( + inner_data.get("next_page_token") + if isinstance(inner_data, dict) + else None + ) + or response_data.get("nextPageToken") + or response_data.get("next_page_token") ) elif isinstance(data, list): files = data @@ -819,24 +844,61 @@ class ComposioService: next_token = None result_size_estimate = None if isinstance(data, dict): + inner_data = data.get("data", data) + response_data = ( + inner_data.get("response_data", {}) + if isinstance(inner_data, dict) + else {} + ) messages = ( data.get("messages", []) - or data.get("data", {}).get("messages", []) + or ( + inner_data.get("messages", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("messages", []) or data.get("emails", []) + or ( + inner_data.get("emails", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("emails", []) ) # Check for pagination token in various possible locations next_token = ( data.get("nextPageToken") or data.get("next_page_token") - or data.get("data", {}).get("nextPageToken") - or data.get("data", {}).get("next_page_token") + or ( + inner_data.get("nextPageToken") + if isinstance(inner_data, dict) + else None + ) + or ( + inner_data.get("next_page_token") + if isinstance(inner_data, dict) + else None + ) + or response_data.get("nextPageToken") + or response_data.get("next_page_token") ) # Extract resultSizeEstimate if available (Gmail API provides this) result_size_estimate = ( data.get("resultSizeEstimate") or data.get("result_size_estimate") - or data.get("data", {}).get("resultSizeEstimate") - or data.get("data", {}).get("result_size_estimate") + or ( + inner_data.get("resultSizeEstimate") + if isinstance(inner_data, dict) + else None + ) + or ( + inner_data.get("result_size_estimate") + if isinstance(inner_data, dict) + else None + ) + or response_data.get("resultSizeEstimate") + or response_data.get("result_size_estimate") ) elif isinstance(data, list): messages = data @@ -864,7 +926,7 @@ class ComposioService: try: result = await self.execute_tool( connected_account_id=connected_account_id, - tool_name="GMAIL_GET_MESSAGE_BY_MESSAGE_ID", + tool_name="GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID", params={"message_id": message_id}, # snake_case entity_id=entity_id, ) @@ -872,7 +934,13 @@ class ComposioService: if not result.get("success"): return None, result.get("error", "Unknown error") - return result.get("data"), None + data = result.get("data") + if isinstance(data, dict): + inner_data = data.get("data", data) + if isinstance(inner_data, dict): + return inner_data.get("response_data", inner_data), None + + return data, None except Exception as e: logger.error(f"Failed to get Gmail message detail: {e!s}") @@ -928,10 +996,27 @@ class ComposioService: # Try different possible response structures events = [] if isinstance(data, dict): + inner_data = data.get("data", data) + response_data = ( + inner_data.get("response_data", {}) + if isinstance(inner_data, dict) + else {} + ) events = ( data.get("items", []) - or data.get("data", {}).get("items", []) + or ( + inner_data.get("items", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("items", []) or data.get("events", []) + or ( + inner_data.get("events", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("events", []) ) elif isinstance(data, list): events = data diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py index 7c55da2e5..45bcfd00f 100644 --- a/surfsense_backend/app/services/connector_service.py +++ b/surfsense_backend/app/services/connector_service.py @@ -1,6 +1,8 @@ import asyncio +import os import time from datetime import datetime +from threading import Lock from typing import Any import httpx @@ -2769,12 +2771,22 @@ class ConnectorService: """ Get all available (enabled) connector types for a search space. + Phase 1.4: results are cached per ``search_space_id`` for + :data:`_DISCOVERY_TTL_SECONDS`. Cache key is independent of session + identity — the cached value is plain data, safe to share across + requests. Invalidate on connector add/update/delete via + :func:`invalidate_connector_discovery_cache`. + Args: search_space_id: The search space ID Returns: List of SearchSourceConnectorType enums for enabled connectors """ + cached = _get_cached_connectors(search_space_id) + if cached is not None: + return list(cached) + query = ( select(SearchSourceConnector.connector_type) .filter( @@ -2784,8 +2796,9 @@ class ConnectorService: ) result = await self.session.execute(query) - connector_types = result.scalars().all() - return list(connector_types) + connector_types = list(result.scalars().all()) + _set_cached_connectors(search_space_id, connector_types) + return connector_types async def get_available_document_types( self, @@ -2794,12 +2807,22 @@ class ConnectorService: """ Get all document types that have at least one document in the search space. + Phase 1.4: cached per ``search_space_id`` for + :data:`_DISCOVERY_TTL_SECONDS`. Invalidate via + :func:`invalidate_connector_discovery_cache` when a connector + finishes indexing new documents (or document types are otherwise + added/removed). + Args: search_space_id: The search space ID Returns: List of document type strings that have documents indexed """ + cached = _get_cached_doc_types(search_space_id) + if cached is not None: + return list(cached) + from sqlalchemy import distinct from app.db import Document @@ -2809,5 +2832,164 @@ class ConnectorService: ) result = await self.session.execute(query) - doc_types = result.scalars().all() - return [str(dt) for dt in doc_types] + doc_types = [str(dt) for dt in result.scalars().all()] + _set_cached_doc_types(search_space_id, doc_types) + return doc_types + + +# --------------------------------------------------------------------------- +# Connector / document-type discovery TTL cache (Phase 1.4) +# --------------------------------------------------------------------------- +# +# Both ``get_available_connectors`` and ``get_available_document_types`` are +# called on EVERY chat turn from ``create_surfsense_deep_agent``. Each query +# hits Postgres and contributes to per-turn agent build latency. Their +# results change infrequently — only when the user adds/edits/removes a +# connector, or when an indexer commits a new document type. A short TTL +# cache (default 30s, env-tunable) collapses N concurrent calls into one +# DB roundtrip with bounded staleness. +# +# Invalidation: connector mutation routes (create / update / delete) call +# ``invalidate_connector_discovery_cache(search_space_id)`` to clear the +# entry for the affected space. Multi-replica deployments still pay one +# DB roundtrip per replica per TTL window, which is fine — staleness is +# bounded and the alternative (cross-replica fanout) is not worth the +# coupling here. + +_DISCOVERY_TTL_SECONDS: float = float( + os.getenv("SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS", "30") +) + +# Per-search-space caches. Keyed by ``search_space_id``; value is +# ``(expires_at_monotonic, payload)``. Plain dicts protected by a lock — +# read-mostly workload, sub-microsecond contention. +_connectors_cache: dict[int, tuple[float, list[SearchSourceConnectorType]]] = {} +_doc_types_cache: dict[int, tuple[float, list[str]]] = {} +_cache_lock = Lock() + + +def _get_cached_connectors( + search_space_id: int, +) -> list[SearchSourceConnectorType] | None: + if _DISCOVERY_TTL_SECONDS <= 0: + return None + with _cache_lock: + entry = _connectors_cache.get(search_space_id) + if entry is None: + return None + expires_at, payload = entry + if time.monotonic() >= expires_at: + _connectors_cache.pop(search_space_id, None) + return None + return payload + + +def _set_cached_connectors( + search_space_id: int, payload: list[SearchSourceConnectorType] +) -> None: + if _DISCOVERY_TTL_SECONDS <= 0: + return + expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS + with _cache_lock: + _connectors_cache[search_space_id] = (expires_at, list(payload)) + + +def _get_cached_doc_types(search_space_id: int) -> list[str] | None: + if _DISCOVERY_TTL_SECONDS <= 0: + return None + with _cache_lock: + entry = _doc_types_cache.get(search_space_id) + if entry is None: + return None + expires_at, payload = entry + if time.monotonic() >= expires_at: + _doc_types_cache.pop(search_space_id, None) + return None + return payload + + +def _set_cached_doc_types(search_space_id: int, payload: list[str]) -> None: + if _DISCOVERY_TTL_SECONDS <= 0: + return + expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS + with _cache_lock: + _doc_types_cache[search_space_id] = (expires_at, list(payload)) + + +def invalidate_connector_discovery_cache(search_space_id: int | None = None) -> None: + """Drop cached discovery results for ``search_space_id`` (or all spaces). + + Connector CRUD routes / indexer pipelines call this when they mutate + the rows backing :func:`ConnectorService.get_available_connectors` / + :func:`get_available_document_types`. ``None`` clears every space — + useful in tests and on bulk imports. + """ + with _cache_lock: + if search_space_id is None: + _connectors_cache.clear() + _doc_types_cache.clear() + else: + _connectors_cache.pop(search_space_id, None) + _doc_types_cache.pop(search_space_id, None) + + +def _invalidate_connectors_only(search_space_id: int | None = None) -> None: + with _cache_lock: + if search_space_id is None: + _connectors_cache.clear() + else: + _connectors_cache.pop(search_space_id, None) + + +def _invalidate_doc_types_only(search_space_id: int | None = None) -> None: + with _cache_lock: + if search_space_id is None: + _doc_types_cache.clear() + else: + _doc_types_cache.pop(search_space_id, None) + + +def _register_invalidation_listeners() -> None: + """Wire SQLAlchemy ORM events so cache stays consistent automatically. + + Listening on ``after_insert`` / ``after_update`` / ``after_delete`` + means every successful INSERT/UPDATE/DELETE that goes through the ORM + invalidates the affected search space's cached discovery payload — + no need to sprinkle ``invalidate_*`` calls across 30+ connector + routes. Bulk operations that bypass the ORM (e.g. + ``session.execute(insert(...))`` without a mapped object) still need + explicit invalidation; document indexers already commit through the + ORM so document-type discovery is covered. + """ + from sqlalchemy import event + + # Imported here (not at module top) to avoid a circular import: + # app.services.connector_service is itself imported from app.db's + # ecosystem indirectly via several CRUD modules. + from app.db import Document, SearchSourceConnector + + def _connector_changed(_mapper, _connection, target) -> None: + sid = getattr(target, "search_space_id", None) + if sid is not None: + _invalidate_connectors_only(int(sid)) + + def _document_changed(_mapper, _connection, target) -> None: + sid = getattr(target, "search_space_id", None) + if sid is not None: + _invalidate_doc_types_only(int(sid)) + + for evt in ("after_insert", "after_update", "after_delete"): + event.listen(SearchSourceConnector, evt, _connector_changed) + event.listen(Document, evt, _document_changed) + + +try: + _register_invalidation_listeners() +except Exception: # pragma: no cover - defensive; never block module import + import logging as _logging + + _logging.getLogger(__name__).exception( + "Failed to register connector discovery cache invalidation listeners; " + "stale cache risk: explicit invalidate_connector_discovery_cache calls " + "may be required." + ) diff --git a/surfsense_backend/app/services/gmail/tool_metadata_service.py b/surfsense_backend/app/services/gmail/tool_metadata_service.py index c903e24af..4855c1cc9 100644 --- a/surfsense_backend/app/services/gmail/tool_metadata_service.py +++ b/surfsense_backend/app/services/gmail/tool_metadata_service.py @@ -17,7 +17,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) -from app.utils.google_credentials import build_composio_credentials +from app.services.composio_service import ComposioService logger = logging.getLogger(__name__) @@ -78,14 +78,49 @@ class GmailToolMetadataService: def __init__(self, db_session: AsyncSession): self._db_session = db_session - async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: - if ( + def _is_composio_connector(self, connector: SearchSourceConnector) -> bool: + return ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - return build_composio_credentials(cca_id) + ) + + def _get_composio_connected_account_id( + self, connector: SearchSourceConnector + ) -> str: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected_account_id not found") + return cca_id + + def _unwrap_composio_data(self, data: Any) -> Any: + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner) + return inner + return data + + async def _execute_composio_gmail_tool( + self, + connector: SearchSourceConnector, + tool_name: str, + params: dict[str, Any], + ) -> tuple[Any, str | None]: + result = await ComposioService().execute_tool( + connected_account_id=self._get_composio_connected_account_id(connector), + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{connector.user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Gmail error") + return self._unwrap_composio_data(result.get("data")), None + + async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: + if self._is_composio_connector(connector): + raise ValueError( + "Composio Gmail connectors must use Composio tool execution" + ) config_data = dict(connector.config) @@ -139,6 +174,12 @@ class GmailToolMetadataService: if not connector: return True + if self._is_composio_connector(connector): + _profile, error = await self._execute_composio_gmail_tool( + connector, "GMAIL_GET_PROFILE", {"user_id": "me"} + ) + return bool(error) + creds = await self._build_credentials(connector) service = build("gmail", "v1", credentials=creds) await asyncio.get_event_loop().run_in_executor( @@ -221,14 +262,21 @@ class GmailToolMetadataService: ) connector = result.scalar_one_or_none() if connector: - creds = await self._build_credentials(connector) - service = build("gmail", "v1", credentials=creds) - profile = await asyncio.get_event_loop().run_in_executor( - None, - lambda service=service: ( - service.users().getProfile(userId="me").execute() - ), - ) + if self._is_composio_connector(connector): + profile, error = await self._execute_composio_gmail_tool( + connector, "GMAIL_GET_PROFILE", {"user_id": "me"} + ) + if error: + raise RuntimeError(error) + else: + creds = await self._build_credentials(connector) + service = build("gmail", "v1", credentials=creds) + profile = await asyncio.get_event_loop().run_in_executor( + None, + lambda service=service: ( + service.users().getProfile(userId="me").execute() + ), + ) acc_dict["email"] = profile.get("emailAddress", "") except Exception: logger.warning( @@ -298,6 +346,23 @@ class GmailToolMetadataService: Returns ``None`` on any failure so callers can degrade gracefully. """ try: + if self._is_composio_connector(connector): + if not draft_id: + draft_id = await self._find_composio_draft_id(connector, message_id) + if not draft_id: + return None + + draft, error = await self._execute_composio_gmail_tool( + connector, + "GMAIL_GET_DRAFT", + {"user_id": "me", "draft_id": draft_id, "format": "full"}, + ) + if error or not isinstance(draft, dict): + return None + + payload = draft.get("message", {}).get("payload", {}) + return self._extract_body_from_payload(payload) + creds = await self._build_credentials(connector) service = build("gmail", "v1", credentials=creds) @@ -326,6 +391,33 @@ class GmailToolMetadataService: ) return None + async def _find_composio_draft_id( + self, connector: SearchSourceConnector, message_id: str + ) -> str | None: + page_token = "" + while True: + params: dict[str, Any] = { + "user_id": "me", + "max_results": 100, + "verbose": False, + } + if page_token: + params["page_token"] = page_token + + data, error = await self._execute_composio_gmail_tool( + connector, "GMAIL_LIST_DRAFTS", params + ) + if error or not isinstance(data, dict): + return None + + for draft in data.get("drafts", []): + if draft.get("message", {}).get("id") == message_id: + return draft.get("id") + + page_token = data.get("nextPageToken") or data.get("next_page_token") or "" + if not page_token: + return None + async def _find_draft_id(self, service: Any, message_id: str) -> str | None: """Resolve a draft ID from its message ID by scanning drafts.list.""" try: diff --git a/surfsense_backend/app/services/google_calendar/kb_sync_service.py b/surfsense_backend/app/services/google_calendar/kb_sync_service.py index 20426f3bc..602a55738 100644 --- a/surfsense_backend/app/services/google_calendar/kb_sync_service.py +++ b/surfsense_backend/app/services/google_calendar/kb_sync_service.py @@ -14,6 +14,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) +from app.services.composio_service import ComposioService from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -21,7 +22,6 @@ from app.utils.document_converters import ( generate_document_summary, generate_unique_identifier_hash, ) -from app.utils.google_credentials import build_composio_credentials logger = logging.getLogger(__name__) @@ -203,23 +203,46 @@ class GoogleCalendarKBSyncService: logger.warning("Document %s not found in KB", document_id) return {"status": "not_indexed"} - creds = await self._build_credentials_for_connector(connector_id) - loop = asyncio.get_event_loop() - service = await loop.run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - calendar_id = (document.document_metadata or {}).get( "calendar_id" ) or "primary" - live_event = await loop.run_in_executor( - None, - lambda: ( - service.events() - .get(calendarId=calendar_id, eventId=event_id) - .execute() - ), - ) + connector = await self._get_connector(connector_id) + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ): + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected_account_id not found") + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_EVENTS_GET", + params={"calendar_id": calendar_id, "event_id": event_id}, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get("error", "Unknown Composio Calendar error") + ) + live_event = composio_result.get("data", {}) + if isinstance(live_event, dict): + live_event = live_event.get("data", live_event) + if isinstance(live_event, dict): + live_event = live_event.get("response_data", live_event) + else: + creds = await self._build_credentials_for_connector(connector_id) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + live_event = await loop.run_in_executor( + None, + lambda: ( + service.events() + .get(calendarId=calendar_id, eventId=event_id) + .execute() + ), + ) event_summary = live_event.get("summary", "") description = live_event.get("description", "") @@ -322,7 +345,7 @@ class GoogleCalendarKBSyncService: await self.db_session.rollback() return {"status": "error", "message": str(e)} - async def _build_credentials_for_connector(self, connector_id: int) -> Credentials: + async def _get_connector(self, connector_id: int) -> SearchSourceConnector: result = await self.db_session.execute( select(SearchSourceConnector).where( SearchSourceConnector.id == connector_id @@ -331,15 +354,17 @@ class GoogleCalendarKBSyncService: connector = result.scalar_one_or_none() if not connector: raise ValueError(f"Connector {connector_id} not found") + return connector + async def _build_credentials_for_connector(self, connector_id: int) -> Credentials: + connector = await self._get_connector(connector_id) if ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - return build_composio_credentials(cca_id) - raise ValueError("Composio connected_account_id not found") + raise ValueError( + "Composio Calendar connectors must use Composio tool execution" + ) config_data = dict(connector.config) diff --git a/surfsense_backend/app/services/google_calendar/tool_metadata_service.py b/surfsense_backend/app/services/google_calendar/tool_metadata_service.py index c7bfe1d50..7e50ab039 100644 --- a/surfsense_backend/app/services/google_calendar/tool_metadata_service.py +++ b/surfsense_backend/app/services/google_calendar/tool_metadata_service.py @@ -16,7 +16,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) -from app.utils.google_credentials import build_composio_credentials +from app.services.composio_service import ComposioService logger = logging.getLogger(__name__) @@ -94,15 +94,49 @@ class GoogleCalendarToolMetadataService: def __init__(self, db_session: AsyncSession): self._db_session = db_session - async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: - if ( + def _is_composio_connector(self, connector: SearchSourceConnector) -> bool: + return ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - return build_composio_credentials(cca_id) + ) + + def _get_composio_connected_account_id( + self, connector: SearchSourceConnector + ) -> str: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: raise ValueError("Composio connected_account_id not found") + return cca_id + + async def _execute_composio_calendar_tool( + self, + connector: SearchSourceConnector, + tool_name: str, + params: dict, + ) -> tuple[dict | list | None, str | None]: + service = ComposioService() + result = await service.execute_tool( + connected_account_id=self._get_composio_connected_account_id(connector), + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{connector.user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Calendar error") + + data = result.get("data") + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner), None + return inner, None + return data, None + + async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: + if self._is_composio_connector(connector): + raise ValueError( + "Composio Calendar connectors must use Composio tool execution" + ) config_data = dict(connector.config) @@ -156,6 +190,14 @@ class GoogleCalendarToolMetadataService: if not connector: return True + if self._is_composio_connector(connector): + _data, error = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_GET_CALENDAR", + {"calendar_id": "primary"}, + ) + return bool(error) + creds = await self._build_credentials(connector) loop = asyncio.get_event_loop() await loop.run_in_executor( @@ -255,16 +297,48 @@ class GoogleCalendarToolMetadataService: timezone_str = "" if connector: try: - creds = await self._build_credentials(connector) - loop = asyncio.get_event_loop() - service = await loop.run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) + if self._is_composio_connector(connector): + cal_list, cal_error = await self._execute_composio_calendar_tool( + connector, "GOOGLECALENDAR_LIST_CALENDARS", {} + ) + if cal_error: + raise RuntimeError(cal_error) + ( + settings, + settings_error, + ) = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_SETTINGS_GET", + {"setting": "timezone"}, + ) + if not settings_error and isinstance(settings, dict): + timezone_str = settings.get("value", "") + else: + creds = await self._build_credentials(connector) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) - cal_list = await loop.run_in_executor( - None, lambda: service.calendarList().list().execute() - ) - for cal in cal_list.get("items", []): + cal_list = await loop.run_in_executor( + None, lambda: service.calendarList().list().execute() + ) + + tz_setting = await loop.run_in_executor( + None, + lambda: service.settings().get(setting="timezone").execute(), + ) + timezone_str = tz_setting.get("value", "") + + calendar_items = [] + if isinstance(cal_list, dict): + calendar_items = ( + cal_list.get("items") or cal_list.get("calendars") or [] + ) + elif isinstance(cal_list, list): + calendar_items = cal_list + + for cal in calendar_items: calendars.append( { "id": cal.get("id", ""), @@ -272,12 +346,6 @@ class GoogleCalendarToolMetadataService: "primary": cal.get("primary", False), } ) - - tz_setting = await loop.run_in_executor( - None, - lambda: service.settings().get(setting="timezone").execute(), - ) - timezone_str = tz_setting.get("value", "") except Exception: logger.warning( "Failed to fetch calendars/timezone for connector %s", @@ -321,20 +389,29 @@ class GoogleCalendarToolMetadataService: event_dict = event.to_dict() try: - creds = await self._build_credentials(connector) - loop = asyncio.get_event_loop() - service = await loop.run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) calendar_id = event.calendar_id or "primary" - live_event = await loop.run_in_executor( - None, - lambda: ( - service.events() - .get(calendarId=calendar_id, eventId=event.event_id) - .execute() - ), - ) + if self._is_composio_connector(connector): + live_event, error = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_EVENTS_GET", + {"calendar_id": calendar_id, "event_id": event.event_id}, + ) + if error: + raise RuntimeError(error) + else: + creds = await self._build_credentials(connector) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + live_event = await loop.run_in_executor( + None, + lambda: ( + service.events() + .get(calendarId=calendar_id, eventId=event.event_id) + .execute() + ), + ) event_dict["summary"] = live_event.get("summary", event_dict["summary"]) event_dict["description"] = live_event.get( @@ -376,12 +453,30 @@ class GoogleCalendarToolMetadataService: ) -> dict: resolved = await self._resolve_event(search_space_id, user_id, event_ref) if not resolved: + live_resolved = await self._resolve_live_event( + search_space_id, user_id, event_ref + ) + if not live_resolved: + return { + "error": ( + f"Event '{event_ref}' not found in your indexed or live Google Calendar events. " + "This could mean: (1) the event doesn't exist, " + "(2) the event name is different, or " + "(3) the connected calendar account cannot access it." + ) + } + + connector, live_event = live_resolved + account = GoogleCalendarAccount.from_connector(connector) + acc_dict = account.to_dict() + auth_expired = await self._check_account_health(connector.id) + acc_dict["auth_expired"] = auth_expired + if auth_expired: + await self._persist_auth_expired(connector.id) + return { - "error": ( - f"Event '{event_ref}' not found in your indexed Google Calendar events. " - "This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, " - "or (3) the event name is different." - ) + "account": acc_dict, + "event": self._event_dict_from_live_event(live_event), } document, connector = resolved @@ -429,3 +524,110 @@ class GoogleCalendarToolMetadataService: if row: return row[0], row[1] return None + + async def _resolve_live_event( + self, search_space_id: int, user_id: str, event_ref: str + ) -> tuple[SearchSourceConnector, dict] | None: + result = await self._db_session.execute( + select(SearchSourceConnector) + .filter( + and_( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES), + ) + ) + .order_by(SearchSourceConnector.last_indexed_at.desc()) + ) + connectors = result.scalars().all() + + for connector in connectors: + try: + events = await self._search_live_events(connector, event_ref) + except Exception: + logger.warning( + "Failed to search live calendar events for connector %s", + connector.id, + exc_info=True, + ) + continue + + if not events: + continue + + normalized_ref = event_ref.strip().lower() + exact_match = next( + ( + event + for event in events + if event.get("summary", "").strip().lower() == normalized_ref + ), + None, + ) + return connector, exact_match or events[0] + + return None + + async def _search_live_events( + self, connector: SearchSourceConnector, event_ref: str + ) -> list[dict]: + if self._is_composio_connector(connector): + data, error = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_EVENTS_LIST", + { + "calendar_id": "primary", + "q": event_ref, + "max_results": 10, + "single_events": True, + "order_by": "startTime", + }, + ) + if error: + raise RuntimeError(error) + if isinstance(data, dict): + return data.get("items") or data.get("events") or [] + return data if isinstance(data, list) else [] + + creds = await self._build_credentials(connector) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + response = await loop.run_in_executor( + None, + lambda: ( + service.events() + .list( + calendarId="primary", + q=event_ref, + maxResults=10, + singleEvents=True, + orderBy="startTime", + ) + .execute() + ), + ) + return response.get("items", []) + + def _event_dict_from_live_event(self, event: dict) -> dict: + start_data = event.get("start", {}) + end_data = event.get("end", {}) + return { + "event_id": event.get("id", ""), + "summary": event.get("summary", "No Title"), + "start": start_data.get("dateTime", start_data.get("date", "")), + "end": end_data.get("dateTime", end_data.get("date", "")), + "description": event.get("description", ""), + "location": event.get("location", ""), + "attendees": [ + { + "email": attendee.get("email", ""), + "responseStatus": attendee.get("responseStatus", ""), + } + for attendee in event.get("attendees", []) + ], + "calendar_id": event.get("calendarId", "primary"), + "document_id": None, + "indexed_at": None, + } diff --git a/surfsense_backend/app/services/google_drive/tool_metadata_service.py b/surfsense_backend/app/services/google_drive/tool_metadata_service.py index 221bee14a..0f654bc78 100644 --- a/surfsense_backend/app/services/google_drive/tool_metadata_service.py +++ b/surfsense_backend/app/services/google_drive/tool_metadata_service.py @@ -13,7 +13,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) -from app.utils.google_credentials import build_composio_credentials +from app.services.composio_service import ComposioService logger = logging.getLogger(__name__) @@ -67,6 +67,42 @@ class GoogleDriveToolMetadataService: def __init__(self, db_session: AsyncSession): self._db_session = db_session + def _is_composio_connector(self, connector: SearchSourceConnector) -> bool: + return ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR + ) + + def _get_composio_connected_account_id( + self, connector: SearchSourceConnector + ) -> str: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected_account_id not found") + return cca_id + + async def _execute_composio_drive_tool( + self, + connector: SearchSourceConnector, + tool_name: str, + params: dict, + ) -> tuple[dict | list | None, str | None]: + result = await ComposioService().execute_tool( + connected_account_id=self._get_composio_connected_account_id(connector), + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{connector.user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Drive error") + data = result.get("data") + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner), None + return inner, None + return data, None + async def get_creation_context(self, search_space_id: int, user_id: str) -> dict: accounts = await self._get_google_drive_accounts(search_space_id, user_id) @@ -200,19 +236,21 @@ class GoogleDriveToolMetadataService: if not connector: return True - pre_built_creds = None - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) + if self._is_composio_connector(connector): + _data, error = await self._execute_composio_drive_tool( + connector, + "GOOGLEDRIVE_LIST_FILES", + { + "q": "trashed = false", + "page_size": 1, + "fields": "files(id)", + }, + ) + return bool(error) client = GoogleDriveClient( session=self._db_session, connector_id=connector_id, - credentials=pre_built_creds, ) await client.list_files( query="trashed = false", page_size=1, fields="files(id)" @@ -274,19 +312,39 @@ class GoogleDriveToolMetadataService: parent_folders[connector_id] = [] continue - pre_built_creds = None - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) + if self._is_composio_connector(connector): + data, error = await self._execute_composio_drive_tool( + connector, + "GOOGLEDRIVE_LIST_FILES", + { + "q": "mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents", + "fields": "files(id,name)", + "page_size": 50, + }, + ) + if error: + logger.warning( + "Failed to list folders for connector %s: %s", + connector_id, + error, + ) + parent_folders[connector_id] = [] + continue + folders = [] + if isinstance(data, dict): + folders = data.get("files", []) + elif isinstance(data, list): + folders = data + parent_folders[connector_id] = [ + {"folder_id": f["id"], "name": f["name"]} + for f in folders + if f.get("id") and f.get("name") + ] + continue client = GoogleDriveClient( session=self._db_session, connector_id=connector_id, - credentials=pre_built_creds, ) folders, _, error = await client.list_files( diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index f45a6ab63..b4de2a0bf 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -20,6 +20,8 @@ from typing import Any from litellm import Router from litellm.utils import ImageResponse +from app.services.provider_api_base import resolve_api_base + logger = logging.getLogger(__name__) # Special ID for Auto mode - uses router for load balancing @@ -152,12 +154,12 @@ class ImageGenRouterService: return None # Build model string + provider = config.get("provider", "").upper() if config.get("custom_provider"): - model_string = f"{config['custom_provider']}/{config['model_name']}" + provider_prefix = config["custom_provider"] else: - provider = config.get("provider", "").upper() provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" + model_string = f"{provider_prefix}/{config['model_name']}" # Build litellm params litellm_params: dict[str, Any] = { @@ -165,9 +167,16 @@ class ImageGenRouterService: "api_key": config.get("api_key"), } - # Add optional api_base - if config.get("api_base"): - litellm_params["api_base"] = config["api_base"] + # Resolve ``api_base`` so deployments don't silently inherit + # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against + # the wrong provider (see ``provider_api_base`` docstring). + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) + if api_base: + litellm_params["api_base"] = api_base # Add api_version (required for Azure) if config.get("api_version"): diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 8a7b2919a..d220aa346 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -134,42 +134,14 @@ PROVIDER_MAP = { } -# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when -# a global LLM config does *not* specify ``api_base``: without this, LiteLLM -# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``, -# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku`` -# request to an Azure endpoint, which then 404s with ``Resource not found``. -# Only providers with a well-known, stable public base URL are listed here — -# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, -# huggingface, databricks, cloudflare, replicate) are intentionally omitted -# so their existing config-driven behaviour is preserved. -PROVIDER_DEFAULT_API_BASE = { - "openrouter": "https://openrouter.ai/api/v1", - "groq": "https://api.groq.com/openai/v1", - "mistral": "https://api.mistral.ai/v1", - "perplexity": "https://api.perplexity.ai", - "xai": "https://api.x.ai/v1", - "cerebras": "https://api.cerebras.ai/v1", - "deepinfra": "https://api.deepinfra.com/v1/openai", - "fireworks_ai": "https://api.fireworks.ai/inference/v1", - "together_ai": "https://api.together.xyz/v1", - "anyscale": "https://api.endpoints.anyscale.com/v1", - "cometapi": "https://api.cometapi.com/v1", - "sambanova": "https://api.sambanova.ai/v1", -} - - -# Canonical provider → base URL when a config uses a generic ``openai``-style -# prefix but the ``provider`` field tells us which API it really is -# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but -# each has its own base URL). -PROVIDER_KEY_DEFAULT_API_BASE = { - "DEEPSEEK": "https://api.deepseek.com/v1", - "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", - "MOONSHOT": "https://api.moonshot.ai/v1", - "ZHIPU": "https://open.bigmodel.cn/api/paas/v4", - "MINIMAX": "https://api.minimax.io/v1", -} +# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were +# hoisted to ``app.services.provider_api_base`` so vision and image-gen +# call sites can share the exact same defense (OpenRouter / Groq / etc. +# 404-ing against an inherited Azure endpoint). Re-exported here for +# backward compatibility with any external import. +from app.services.provider_api_base import ( # noqa: E402 + resolve_api_base, +) class LLMRouterService: @@ -466,14 +438,14 @@ class LLMRouterService: # Resolve ``api_base``. Config value wins; otherwise apply a # provider-aware default so the deployment does not silently # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route - # requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE`` + # requests to the wrong endpoint. See ``provider_api_base`` # docstring for the motivating bug (OpenRouter models 404-ing # against an Azure endpoint). - api_base = config.get("api_base") - if not api_base: - api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider) - if not api_base: - api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix) + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) if api_base: litellm_params["api_base"] = api_base diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 942a9b7af..ade202c72 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -16,6 +16,7 @@ from app.services.llm_router_service import ( get_auto_mode_llm, is_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import token_tracker # Configure litellm to automatically drop unsupported parameters @@ -496,8 +497,14 @@ async def get_vision_llm( - Auto mode (ID 0): VisionLLMRouterService - Global (negative ID): YAML configs - DB (positive ID): VisionLLMConfig table + + Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM` + so each ``ainvoke`` debits the search-space owner's premium credit + pool. User-owned BYOK configs and free global configs are returned + unwrapped — they don't consume premium credit (issue M). """ from app.db import VisionLLMConfig + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM from app.services.vision_llm_router_service import ( VISION_PROVIDER_MAP, VisionLLMRouterService, @@ -519,6 +526,8 @@ async def get_vision_llm( logger.error(f"No vision LLM configured for search space {search_space_id}") return None + owner_user_id = search_space.user_id + if is_vision_auto_mode(config_id): if not VisionLLMRouterService.is_initialized(): logger.error( @@ -526,6 +535,13 @@ async def get_vision_llm( ) return None try: + # Auto mode is currently treated as free at the wrapper + # level — the underlying router can dispatch to either + # premium or free YAML configs but routing decisions are + # opaque. If/when we want to bill Auto-routed vision + # calls we'd need to thread the resolved deployment's + # billing_tier back from the router. For now we keep + # parity with chat Auto, which also doesn't pre-classify. return ChatLiteLLMRouter( router=VisionLLMRouterService.get_router(), streaming=True, @@ -541,29 +557,46 @@ async def get_vision_llm( return None if global_cfg.get("custom_provider"): - model_string = ( - f"{global_cfg['custom_provider']}/{global_cfg['model_name']}" - ) + provider_prefix = global_cfg["custom_provider"] + model_string = f"{provider_prefix}/{global_cfg['model_name']}" else: - prefix = VISION_PROVIDER_MAP.get( + provider_prefix = VISION_PROVIDER_MAP.get( global_cfg["provider"].upper(), global_cfg["provider"].lower(), ) - model_string = f"{prefix}/{global_cfg['model_name']}" + model_string = f"{provider_prefix}/{global_cfg['model_name']}" litellm_kwargs = { "model": model_string, "api_key": global_cfg["api_key"], } - if global_cfg.get("api_base"): - litellm_kwargs["api_base"] = global_cfg["api_base"] + api_base = resolve_api_base( + provider=global_cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=global_cfg.get("api_base"), + ) + if api_base: + litellm_kwargs["api_base"] = api_base if global_cfg.get("litellm_params"): litellm_kwargs.update(global_cfg["litellm_params"]) from app.agents.new_chat.llm_config import SanitizedChatLiteLLM - return SanitizedChatLiteLLM(**litellm_kwargs) + inner_llm = SanitizedChatLiteLLM(**litellm_kwargs) + billing_tier = str(global_cfg.get("billing_tier", "free")).lower() + if billing_tier == "premium": + return QuotaCheckedVisionLLM( + inner_llm, + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=model_string, + quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"), + ) + return inner_llm + + # User-owned (positive ID) BYOK configs — always free. result = await session.execute( select(VisionLLMConfig).where( VisionLLMConfig.id == config_id, @@ -578,20 +611,26 @@ async def get_vision_llm( return None if vision_cfg.custom_provider: - model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}" + provider_prefix = vision_cfg.custom_provider + model_string = f"{provider_prefix}/{vision_cfg.model_name}" else: - prefix = VISION_PROVIDER_MAP.get( + provider_prefix = VISION_PROVIDER_MAP.get( vision_cfg.provider.value.upper(), vision_cfg.provider.value.lower(), ) - model_string = f"{prefix}/{vision_cfg.model_name}" + model_string = f"{provider_prefix}/{vision_cfg.model_name}" litellm_kwargs = { "model": model_string, "api_key": vision_cfg.api_key, } - if vision_cfg.api_base: - litellm_kwargs["api_base"] = vision_cfg.api_base + api_base = resolve_api_base( + provider=vision_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=vision_cfg.api_base, + ) + if api_base: + litellm_kwargs["api_base"] = api_base if vision_cfg.litellm_params: litellm_kwargs.update(vision_cfg.litellm_params) diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 7e856d015..6454e2d58 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -93,6 +93,53 @@ def _is_text_output_model(model: dict) -> bool: return output_mods == ["text"] +def _is_image_output_model(model: dict) -> bool: + """Return True if the model can produce image output. + + OpenRouter's ``architecture.output_modalities`` is a list (e.g. + ``["image"]`` for pure image generators, ``["text", "image"]`` for + multi-modal generators that also emit captions). We accept any model + that can output images; the call site decides whether to use the + image-generation API or chat completion. + """ + output_mods = model.get("architecture", {}).get("output_modalities", []) or [] + return "image" in output_mods + + +def _is_vision_input_model(model: dict) -> bool: + """Return True if the model can ingest an image AND emit text. + + OpenRouter's ``architecture.input_modalities`` lists what the model + accepts; ``output_modalities`` lists what it produces. A vision LLM + is a model that takes images in and produces text out — i.e. it can + answer questions about a screenshot or extract content from an + image. Pure image-to-image models (e.g. style transfer) and + text-only models are excluded. + """ + arch = model.get("architecture", {}) or {} + input_mods = arch.get("input_modalities", []) or [] + output_mods = arch.get("output_modalities", []) or [] + return "image" in input_mods and "text" in output_mods + + +def _supports_image_input(model: dict) -> bool: + """Return True if the model accepts ``image`` in its input modalities. + + Differs from :func:`_is_vision_input_model` in that it does NOT + require text output — chat-tab models always emit text already (the + chat catalog filters by ``_is_text_output_model``), so the only + extra capability we need to track per chat config is whether the + model can ingest user-attached images. The chat selector and the + streaming task both key off this flag to prevent hitting an + OpenRouter 404 ``"No endpoints found that support image input"`` + when the user uploads an image and selects a text-only model + (DeepSeek V3, Llama 3.x base, etc.). + """ + arch = model.get("architecture", {}) or {} + input_mods = arch.get("input_modalities", []) or [] + return "image" in input_mods + + def _supports_tool_calling(model: dict) -> bool: """Return True if the model supports function/tool calling.""" supported = model.get("supported_parameters") or [] @@ -175,6 +222,32 @@ async def _fetch_models_async() -> list[dict] | None: return None +def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]: + """Return a ``{model_id: {"prompt": str, "completion": str}}`` map. + + Pricing values are kept as the raw OpenRouter strings (e.g. + ``"0.000003"``); ``pricing_registration`` converts them to floats + when registering with LiteLLM. Models with missing or malformed + pricing are simply omitted — operator-side risk if any of those are + premium. + """ + pricing: dict[str, dict[str, str]] = {} + for model in raw_models: + model_id = str(model.get("id") or "").strip() + if not model_id: + continue + p = model.get("pricing") or {} + prompt = p.get("prompt") + completion = p.get("completion") + if prompt is None and completion is None: + continue + pricing[model_id] = { + "prompt": str(prompt) if prompt is not None else "", + "completion": str(completion) if completion is not None else "", + } + return pricing + + def _generate_configs( raw_models: list[dict], settings: dict[str, Any], @@ -266,6 +339,13 @@ def _generate_configs( # account-wide quota, so per-deployment routing can't spread load # there — it just drains the shared bucket faster. "router_pool_eligible": tier == "premium", + # Capability flag derived from ``architecture.input_modalities``. + # Read by the new-chat selector to dim image-incompatible models + # when the user has pending image attachments, and by + # ``stream_new_chat`` as a fail-fast safety net before the + # OpenRouter request would otherwise 404 with + # ``"No endpoints found that support image input"``. + "supports_image_input": _supports_image_input(model), _OPENROUTER_DYNAMIC_MARKER: True, # Auto (Fastest) ranking metadata. ``quality_score`` is initialised # to the static score and gets re-blended with health on the next @@ -282,6 +362,171 @@ def _generate_configs( return configs +# ID-offset bands used to keep dynamic OpenRouter configs in their own +# namespace per surface. Image / vision get separate bands so a single +# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to. +_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000 +_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000 + + +def _generate_image_gen_configs( + raw_models: list[dict], settings: dict[str, Any] +) -> list[dict]: + """Convert OpenRouter image-generation models into global image-gen + config dicts (matches the YAML shape consumed by ``image_generation_routes``). + + Filter: + - architecture.output_modalities contains "image" + - compatible provider (excluded slugs blocked) + - allowed model id (excluded list blocked) + + Notably we *drop* the chat-only filters (``_supports_tool_calling`` and + ``_has_sufficient_context``) because tool calls and context windows are + irrelevant for the ``aimage_generation`` API. ``billing_tier`` is + derived per model the same way as chat (``_openrouter_tier``). + + Cost is intentionally *not* registered with LiteLLM at startup + (``pricing_registration`` skips image gen): OpenRouter image-gen + models are not in LiteLLM's native cost map and OpenRouter populates + ``response_cost`` directly from the response header. A defensive + branch in ``_extract_cost_usd`` handles the rare case where + ``usage.cost`` is missing — see ``token_tracking_service``. + """ + id_offset: int = int( + settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT + ) + api_key: str = settings.get("api_key", "") + rpm: int = settings.get("rpm", 200) + free_rpm: int = settings.get("free_rpm", 20) + litellm_params: dict = settings.get("litellm_params") or {} + + image_models = [ + m + for m in raw_models + if _is_image_output_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] + + configs: list[dict] = [] + taken: set[int] = set() + for model in image_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) + tier = _openrouter_tier(model) + + cfg: dict[str, Any] = { + "id": _stable_config_id(model_id, id_offset, taken), + "name": name, + "description": f"{name} via OpenRouter (image generation)", + "provider": "OPENROUTER", + "model_name": model_id, + "api_key": api_key, + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` and 404 on + # ``image_generation/transformation`` (defense-in-depth, see + # ``provider_api_base`` docstring). + "api_base": "https://openrouter.ai/api/v1", + "api_version": None, + "rpm": free_rpm if tier == "free" else rpm, + "litellm_params": dict(litellm_params), + "billing_tier": tier, + _OPENROUTER_DYNAMIC_MARKER: True, + } + configs.append(cfg) + + return configs + + +def _generate_vision_llm_configs( + raw_models: list[dict], settings: dict[str, Any] +) -> list[dict]: + """Convert OpenRouter vision-capable LLMs into global vision-LLM config + dicts (matches the YAML shape consumed by ``vision_llm_routes``). + + Filter: + - architecture.input_modalities contains "image" + - architecture.output_modalities contains "text" + - compatible provider (excluded slugs blocked) + - allowed model id (excluded list blocked) + + Vision-LLM is invoked from the indexer (image extraction during + document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so + the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context`` + filters do not apply: a small-context vision model that doesn't + advertise tool-calling is still perfectly viable for "describe this + image" prompts. + """ + id_offset: int = int( + settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT + ) + api_key: str = settings.get("api_key", "") + rpm: int = settings.get("rpm", 200) + tpm: int = settings.get("tpm", 1_000_000) + free_rpm: int = settings.get("free_rpm", 20) + free_tpm: int = settings.get("free_tpm", 100_000) + quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) + litellm_params: dict = settings.get("litellm_params") or {} + + vision_models = [ + m + for m in raw_models + if _is_vision_input_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] + + configs: list[dict] = [] + taken: set[int] = set() + for model in vision_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) + tier = _openrouter_tier(model) + pricing = model.get("pricing") or {} + + # Capture per-token prices so ``pricing_registration`` can + # register them with LiteLLM at startup (and so the cost + # estimator in ``estimate_call_reserve_micros`` can resolve + # them at reserve time). + try: + input_cost = float(pricing.get("prompt", 0) or 0) + except (TypeError, ValueError): + input_cost = 0.0 + try: + output_cost = float(pricing.get("completion", 0) or 0) + except (TypeError, ValueError): + output_cost = 0.0 + + cfg: dict[str, Any] = { + "id": _stable_config_id(model_id, id_offset, taken), + "name": name, + "description": f"{name} via OpenRouter (vision)", + "provider": "OPENROUTER", + "model_name": model_id, + "api_key": api_key, + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see + # ``provider_api_base`` docstring). + "api_base": "https://openrouter.ai/api/v1", + "api_version": None, + "rpm": free_rpm if tier == "free" else rpm, + "tpm": free_tpm if tier == "free" else tpm, + "litellm_params": dict(litellm_params), + "billing_tier": tier, + "quota_reserve_tokens": quota_reserve_tokens, + "input_cost_per_token": input_cost or None, + "output_cost_per_token": output_cost or None, + _OPENROUTER_DYNAMIC_MARKER: True, + } + configs.append(cfg) + + return configs + + class OpenRouterIntegrationService: """Singleton that manages the dynamic OpenRouter model catalogue.""" @@ -300,6 +545,19 @@ class OpenRouterIntegrationService: # Shape: {model_name: {"gated": bool, "score": float | None}} self._health_cache: dict[str, dict[str, Any]] = {} self._enrich_task: asyncio.Task | None = None + # Raw OpenRouter pricing per model_id, captured at the same time + # we generate configs. Consumed by ``pricing_registration`` to + # teach LiteLLM the per-token cost of every dynamic deployment so + # the success-callback can populate ``response_cost`` correctly. + self._raw_pricing: dict[str, dict[str, str]] = {} + # Cached raw catalogue from the most recent fetch. Image / vision + # emitters reuse this to avoid a second network call per surface. + self._raw_models: list[dict] = [] + # Image / vision config caches (only populated when the matching + # opt-in flag is true on initialize). Refreshed in lockstep with + # the chat catalogue. + self._image_configs: list[dict] = [] + self._vision_configs: list[dict] = [] @classmethod def get_instance(cls) -> "OpenRouterIntegrationService": @@ -329,8 +587,32 @@ class OpenRouterIntegrationService: self._initialized = True return [] + self._raw_models = raw_models self._configs = _generate_configs(raw_models, settings) self._configs_by_id = {c["id"]: c for c in self._configs} + self._raw_pricing = _extract_raw_pricing(raw_models) + + # Populate image / vision caches when their opt-in flag is set. + # Empty otherwise so the accessors return [] without re-running + # filters every refresh. + if settings.get("image_generation_enabled"): + self._image_configs = _generate_image_gen_configs(raw_models, settings) + logger.info( + "OpenRouter integration: image-gen emission ON (%d models)", + len(self._image_configs), + ) + else: + self._image_configs = [] + + if settings.get("vision_enabled"): + self._vision_configs = _generate_vision_llm_configs(raw_models, settings) + logger.info( + "OpenRouter integration: vision LLM emission ON (%d models)", + len(self._vision_configs), + ) + else: + self._vision_configs = [] + self._initialized = True tier_counts = self._tier_counts(self._configs) @@ -369,6 +651,8 @@ class OpenRouterIntegrationService: new_configs = _generate_configs(raw_models, self._settings) new_by_id = {c["id"]: c for c in new_configs} + self._raw_pricing = _extract_raw_pricing(raw_models) + self._raw_models = raw_models from app.config import config as app_config @@ -382,6 +666,29 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id + # Image / vision lists are atomic-swapped the same way: filter out + # the previous dynamic entries from the live config list and append + # the freshly generated ones. No-ops when the opt-in flag is off. + if self._settings.get("image_generation_enabled"): + new_image = _generate_image_gen_configs(raw_models, self._settings) + static_image = [ + c + for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS + if not c.get(_OPENROUTER_DYNAMIC_MARKER) + ] + app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image + self._image_configs = new_image + + if self._settings.get("vision_enabled"): + new_vision = _generate_vision_llm_configs(raw_models, self._settings) + static_vision = [ + c + for c in app_config.GLOBAL_VISION_LLM_CONFIGS + if not c.get(_OPENROUTER_DYNAMIC_MARKER) + ] + app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision + self._vision_configs = new_vision + # Catalogue churn invalidates per-config "recently healthy" credit # earned by the previous turn's preflight. Drop the whole table so # the next turn re-probes against the freshly loaded configs. @@ -407,6 +714,21 @@ class OpenRouterIntegrationService: # so a hand-picked dead OR model is gated like a dynamic one. await self._enrich_health_safely(static_configs + new_configs, log_summary=True) + # Re-register LiteLLM pricing for the freshly fetched catalogue + # so newly added OR models bill correctly on their first call. + # Runs before the router rebuild because the router may issue + # cost-table lookups during deployment registration. + try: + from app.services.pricing_registration import ( + register_pricing_from_global_configs, + ) + + register_pricing_from_global_configs() + except Exception as exc: + logger.warning( + "OpenRouter refresh: pricing re-registration skipped (%s)", exc + ) + # Rebuild the LiteLLM router so freshly fetched configs flow through # (dynamic OR premium entries now opt into the pool, free ones stay # out; a refresh also needs to pick up any static-config edits and @@ -635,3 +957,34 @@ class OpenRouterIntegrationService: def get_config_by_id(self, config_id: int) -> dict | None: return self._configs_by_id.get(config_id) + + def get_image_generation_configs(self) -> list[dict]: + """Return the dynamic OpenRouter image-generation configs (empty + list when the ``image_generation_enabled`` flag is off). + + Each entry already has ``billing_tier`` derived per-model from + OpenRouter's signals and is shaped to drop directly into + ``Config.GLOBAL_IMAGE_GEN_CONFIGS``. + """ + return list(self._image_configs) + + def get_vision_llm_configs(self) -> list[dict]: + """Return the dynamic OpenRouter vision-LLM configs (empty list + when the ``vision_enabled`` flag is off). + + Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token`` + so ``pricing_registration`` can teach LiteLLM the cost of these + models the same way it does for chat — which keeps the billable + wrapper able to debit accurate micro-USD on a vision call. + """ + return list(self._vision_configs) + + def get_raw_pricing(self) -> dict[str, dict[str, str]]: + """Return the cached raw OpenRouter pricing map. + + Shape: ``{model_id: {"prompt": str, "completion": str}}``. The + values are the strings OpenRouter publishes (USD per token), + never converted to floats here so the caller can decide how to + handle malformed or unset entries. + """ + return dict(self._raw_pricing) diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py new file mode 100644 index 000000000..de98e50c2 --- /dev/null +++ b/surfsense_backend/app/services/pricing_registration.py @@ -0,0 +1,274 @@ +""" +Pricing registration with LiteLLM. + +Many models reach our LiteLLM callback without LiteLLM knowing their +per-token cost — namely: + +* The ~300 dynamic OpenRouter deployments (their pricing only lives on + OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published + pricing table). +* Static YAML deployments whose ``base_model`` name is operator-defined + (e.g. custom Azure deployment names like ``gpt-5.4``) and therefore + not in LiteLLM's table either. + +Without registration, ``kwargs["response_cost"]`` is 0 for those calls +and the user gets billed nothing — a fail-safe but wrong answer for a +cost-based credit system. This module runs once at startup, after the +OpenRouter integration has fetched its catalogue, and registers each +known model's pricing with ``litellm.register_model()`` under multiple +plausible alias keys (LiteLLM's cost lookup may use any of them +depending on whether the call went through the Router, ChatLiteLLM, +or a direct ``acompletion``). + +Operators who run a custom Azure deployment whose ``base_model`` name +isn't in LiteLLM's table can declare per-token pricing inline in +``global_llm_config.yaml`` via ``input_cost_per_token`` and +``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without +that declaration the model's calls debit 0 — never overbilled. +""" + +from __future__ import annotations + +import logging +from typing import Any + +import litellm + +logger = logging.getLogger(__name__) + + +def _safe_float(value: Any) -> float: + """Return ``float(value)`` if it parses to a positive number, else 0.0.""" + if value is None: + return 0.0 + try: + f = float(value) + except (TypeError, ValueError): + return 0.0 + return f if f > 0 else 0.0 + + +def _alias_set_for_openrouter(model_id: str) -> list[str]: + """Return the alias keys to register an OpenRouter model under. + + LiteLLM's cost-callback lookup key varies by call path: + - Router with ``model="openrouter/X"`` → kwargs["model"] is + typically ``openrouter/X``. + - LiteLLM's own provider routing may strip the prefix and pass the + bare ``X`` to the cost-table lookup. + Registering under both keeps the lookup hermetic regardless of + which path the call took. + """ + aliases = [f"openrouter/{model_id}", model_id] + return list(dict.fromkeys(a for a in aliases if a)) + + +def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]: + """Return the alias keys to register a static YAML deployment under. + + Same reasoning as the OpenRouter set: cover the bare ``base_model``, + the ``/`` form LiteLLM Router constructs, and the + bare ``model_name`` because callbacks sometimes see whichever was + configured first. + """ + provider_lower = (provider or "").lower() + aliases: list[str] = [] + if base_model: + aliases.append(base_model) + if provider_lower and base_model: + aliases.append(f"{provider_lower}/{base_model}") + if model_name and model_name != base_model: + aliases.append(model_name) + if provider_lower and model_name and model_name != base_model: + aliases.append(f"{provider_lower}/{model_name}") + # Azure deployments often surface as "azure/"; normalise the + # ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``. + if provider_lower == "azure_openai": + if base_model: + aliases.append(f"azure/{base_model}") + if model_name and model_name != base_model: + aliases.append(f"azure/{model_name}") + return list(dict.fromkeys(a for a in aliases if a)) + + +def _register( + aliases: list[str], + *, + input_cost: float, + output_cost: float, + provider: str, + mode: str = "chat", +) -> int: + """Register a single pricing entry under every alias in ``aliases``. + + Returns the count of aliases successfully registered. + """ + payload: dict[str, dict[str, Any]] = {} + for alias in aliases: + payload[alias] = { + "input_cost_per_token": input_cost, + "output_cost_per_token": output_cost, + "litellm_provider": provider, + "mode": mode, + } + if not payload: + return 0 + try: + litellm.register_model(payload) + except Exception as exc: + logger.warning( + "[PricingRegistration] register_model failed for aliases=%s: %s", + aliases, + exc, + ) + return 0 + return len(payload) + + +def _register_chat_shape_configs( + configs: list[dict], + *, + or_pricing: dict[str, dict[str, str]], + label: str, +) -> tuple[int, int, int, list[str]]: + """Common loop that registers per-token pricing for a list of "chat-shape" + configs (chat or vision LLM — both use ``input_cost_per_token`` / + ``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape). + + Returns ``(registered_models, registered_aliases, skipped, sample_keys)``. + """ + registered_models = 0 + registered_aliases = 0 + skipped_no_pricing = 0 + sample_keys: list[str] = [] + + for cfg in configs: + provider = str(cfg.get("provider") or "").upper() + model_name = str(cfg.get("model_name") or "").strip() + litellm_params = cfg.get("litellm_params") or {} + base_model = str(litellm_params.get("base_model") or model_name).strip() + + if provider == "OPENROUTER": + entry = or_pricing.get(model_name) + if entry: + input_cost = _safe_float(entry.get("prompt")) + output_cost = _safe_float(entry.get("completion")) + else: + # Vision configs from ``_generate_vision_llm_configs`` + # carry their pricing inline because the OpenRouter + # raw-pricing cache is keyed by chat-catalogue model_id; + # vision flows pick up the inline values here. + input_cost = _safe_float(cfg.get("input_cost_per_token")) + output_cost = _safe_float(cfg.get("output_cost_per_token")) + if input_cost == 0.0 and output_cost == 0.0: + skipped_no_pricing += 1 + continue + aliases = _alias_set_for_openrouter(model_name) + count = _register( + aliases, + input_cost=input_cost, + output_cost=output_cost, + provider="openrouter", + ) + if count > 0: + registered_models += 1 + registered_aliases += count + if len(sample_keys) < 6: + sample_keys.extend(aliases[:2]) + continue + + input_cost = _safe_float( + cfg.get("input_cost_per_token") + or litellm_params.get("input_cost_per_token") + ) + output_cost = _safe_float( + cfg.get("output_cost_per_token") + or litellm_params.get("output_cost_per_token") + ) + if input_cost == 0.0 and output_cost == 0.0: + skipped_no_pricing += 1 + continue + aliases = _alias_set_for_yaml(provider, model_name, base_model) + provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower() + count = _register( + aliases, + input_cost=input_cost, + output_cost=output_cost, + provider=provider_slug, + ) + if count > 0: + registered_models += 1 + registered_aliases += count + if len(sample_keys) < 6: + sample_keys.extend(aliases[:2]) + + logger.info( + "[PricingRegistration:%s] registered pricing for %d models (%d aliases); " + "%d configs had no pricing data; sample registered keys=%s", + label, + registered_models, + registered_aliases, + skipped_no_pricing, + sample_keys, + ) + return registered_models, registered_aliases, skipped_no_pricing, sample_keys + + +def register_pricing_from_global_configs() -> None: + """Register pricing for every known LLM deployment with LiteLLM. + + Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS`` + so vision calls (during indexing) can resolve cost the same way chat + calls do — namely: + + 1. ``OPENROUTER``: pulls the cached raw pricing from + ``OpenRouterIntegrationService`` (populated during its own + startup fetch) and converts the per-token strings to floats. For + vision configs that carry pricing inline (``input_cost_per_token`` / + ``output_cost_per_token`` set on the cfg itself) we fall back to + those values when the OR cache misses the model. + 2. Anything else: looks for operator-declared + ``input_cost_per_token`` / ``output_cost_per_token`` on the YAML + config block (top-level or nested under ``litellm_params``). + + **Image generation is intentionally NOT registered here.** The cost + shape for image-gen is per-image (``output_cost_per_image``), not + per-token, and LiteLLM's ``register_model`` doesn't accept those + keys via the chat-cost path. OpenRouter image-gen models populate + ``response_cost`` directly from their response header instead, and + Azure-native image-gen models are already in LiteLLM's cost map. + + Calls without a resolved pair of costs are skipped, not registered + with zeros — operators who forget pricing get a "$0 debit" warning + in ``TokenTrackingCallback`` rather than silently overwriting any + pricing LiteLLM might know natively. + """ + from app.config import config as app_config + + chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or []) + vision_configs: list[dict] = list( + getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or [] + ) + if not chat_configs and not vision_configs: + logger.info("[PricingRegistration] no global configs to register") + return + + or_pricing: dict[str, dict[str, str]] = {} + try: + from app.services.openrouter_integration_service import ( + OpenRouterIntegrationService, + ) + + if OpenRouterIntegrationService.is_initialized(): + or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing() + except Exception as exc: + logger.debug( + "[PricingRegistration] OpenRouter pricing not available yet: %s", exc + ) + + if chat_configs: + _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat") + if vision_configs: + _register_chat_shape_configs( + vision_configs, or_pricing=or_pricing, label="vision" + ) diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py new file mode 100644 index 000000000..dca1f9462 --- /dev/null +++ b/surfsense_backend/app/services/provider_api_base.py @@ -0,0 +1,106 @@ +"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision. + +LiteLLM falls back to the module-global ``litellm.api_base`` when an +individual call doesn't pass one, which silently inherits provider-agnostic +env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an +explicit ``api_base``, an ``openrouter/`` request can end up at an +Azure endpoint and 404 with ``Resource not found`` (real reproducer: +[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends +``/chat/completions`` to whatever inherited base it gets, regardless of +provider). + +The chat router has had this defense for a while +(``llm_router_service.py:466-478``). This module hoists the maps + cascade +into a tiny standalone helper so vision and image-gen can share the same +source of truth without an inter-service circular import. +""" + +from __future__ import annotations + +PROVIDER_DEFAULT_API_BASE: dict[str, str] = { + "openrouter": "https://openrouter.ai/api/v1", + "groq": "https://api.groq.com/openai/v1", + "mistral": "https://api.mistral.ai/v1", + "perplexity": "https://api.perplexity.ai", + "xai": "https://api.x.ai/v1", + "cerebras": "https://api.cerebras.ai/v1", + "deepinfra": "https://api.deepinfra.com/v1/openai", + "fireworks_ai": "https://api.fireworks.ai/inference/v1", + "together_ai": "https://api.together.xyz/v1", + "anyscale": "https://api.endpoints.anyscale.com/v1", + "cometapi": "https://api.cometapi.com/v1", + "sambanova": "https://api.sambanova.ai/v1", +} +"""Default ``api_base`` per LiteLLM provider prefix (lowercase). + +Only providers with a well-known, stable public base URL are listed — +self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, +huggingface, databricks, cloudflare, replicate) are intentionally omitted +so their existing config-driven behaviour is preserved.""" + + +PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = { + "DEEPSEEK": "https://api.deepseek.com/v1", + "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + "MOONSHOT": "https://api.moonshot.ai/v1", + "ZHIPU": "https://open.bigmodel.cn/api/paas/v4", + "MINIMAX": "https://api.minimax.io/v1", +} +"""Canonical provider key (uppercase) → base URL. + +Used when the LiteLLM provider prefix is the generic ``openai`` shim but the +config's ``provider`` field tells us which API it actually is (DeepSeek, +Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each +has its own base URL).""" + + +def resolve_api_base( + *, + provider: str | None, + provider_prefix: str | None, + config_api_base: str | None, +) -> str | None: + """Resolve a non-Azure-leaking ``api_base`` for a deployment. + + Cascade (first non-empty wins): + 1. The config's own ``api_base`` (whitespace-only treated as missing). + 2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``. + 3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``. + 4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM + provider integration apply its own default (e.g. AzureOpenAI's + deployment-derived URL, custom provider's per-deployment URL). + + Args: + provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``, + ``"DEEPSEEK"``). Case-insensitive. + provider_prefix: The LiteLLM model-string prefix the same call + site builds for the model id (e.g. ``"openrouter"``, + ``"groq"``). Case-insensitive. + config_api_base: ``api_base`` from the global YAML / DB row / + OpenRouter dynamic config. Empty / whitespace-only means + "missing" — the resolver still applies the cascade. + + Returns: + A URL string, or ``None`` if no default applies for this provider. + """ + if config_api_base and config_api_base.strip(): + return config_api_base + + if provider: + key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper()) + if key_default: + return key_default + + if provider_prefix: + prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower()) + if prefix_default: + return prefix_default + + return None + + +__all__ = [ + "PROVIDER_DEFAULT_API_BASE", + "PROVIDER_KEY_DEFAULT_API_BASE", + "resolve_api_base", +] diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py new file mode 100644 index 000000000..e9a1c33e1 --- /dev/null +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -0,0 +1,280 @@ +"""Capability resolution shared by chat / image / vision call sites. + +Why this exists +--------------- +The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a +single, authoritative answer to one question: *can this chat config accept +``image_url`` content blocks?* Without it, the new-chat selector can't badge +incompatible models and the streaming task can't fail fast with a friendly +error before sending an image to a text-only provider. + +Two functions, two intents: + +- :func:`derive_supports_image_input` — best-effort *True* for catalog and + UI surfacing. Default-allow: an unknown / unmapped model is treated as + capable so we never lock the user out of a freshly added or + third-party-hosted vision model. + +- :func:`is_known_text_only_chat_model` — strict opt-out for the streaming + task's safety net. Returns True only when LiteLLM's model map *explicitly* + sets ``supports_vision=False`` (or its bare-name variant does). Anything + else — missing key, lookup exception, ``supports_vision=True`` — returns + False so the request flows through to the provider. + +Implementation rule: only public LiteLLM symbols +------------------------------------------------ +``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the +typed module surface (see ``litellm.__init__`` lazy stubs) and are stable +across releases. The private ``_is_explicitly_disabled_factory`` and +``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade +can't silently break us. + +Why the previous round's strict YAML opt-in flag failed +------------------------------------------------------- +``supports_image_input: false`` was the YAML loader's setdefault. Operators +maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI +YAML chat model — including vision-capable GPT-5.x and GPT-4o — resolved to +False and the streaming gate rejected every image turn. Sourcing capability +from LiteLLM's authoritative model map (which already says +``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil. +""" + +from __future__ import annotations + +import logging +from collections.abc import Iterable + +import litellm + +logger = logging.getLogger(__name__) + + +# Provider-name → LiteLLM model-prefix map. +# +# Owned here because ``app.services.provider_capabilities`` is the +# only edge that's safe to call from ``app.config``'s YAML loader at +# class-body init time. ``app.agents.new_chat.llm_config`` re-exports +# this constant under the historical ``PROVIDER_MAP`` name; placing the +# map there directly would re-introduce the +# ``app.config -> ... -> app.agents.new_chat.tools.generate_image -> +# app.config`` cycle that prompted the move. +_PROVIDER_PREFIX_MAP: dict[str, str] = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "XAI": "xai", + "BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "GITHUB_MODELS": "github", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "COMETAPI": "cometapi", + "HUGGINGFACE": "huggingface", + "MINIMAX": "openai", + "CUSTOM": "custom", +} + + +def _candidate_model_strings( + *, + provider: str | None, + model_name: str | None, + base_model: str | None, + custom_provider: str | None, +) -> list[tuple[str, str | None]]: + """Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates. + + LiteLLM's capability lookup is keyed by ``model`` + (optional) + ``custom_llm_provider``. Different config sources give us different + levels of detail, so we try the most-specific keys first and fall back + to bare model names so unannotated entries (e.g. an Azure deployment + pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the + map. Order matters — the first lookup that returns a definitive answer + wins for both helpers. + """ + candidates: list[tuple[str, str | None]] = [] + seen: set[tuple[str, str | None]] = set() + + def _add(model: str | None, llm_provider: str | None) -> None: + if not model: + return + key = (model, llm_provider) + if key in seen: + return + seen.add(key) + candidates.append(key) + + provider_prefix: str | None = None + if provider: + provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower()) + if custom_provider: + # ``custom_provider`` overrides everything for CUSTOM/proxy setups. + provider_prefix = custom_provider + + primary_model = base_model or model_name + bare_model = model_name + + # Most-specific first: provider-prefixed identifier with explicit + # custom_llm_provider so LiteLLM won't have to guess the provider via + # ``get_llm_provider``. + if primary_model and provider_prefix: + # e.g. "azure/gpt-5.4" + custom_llm_provider="azure" + if "/" in primary_model: + _add(primary_model, provider_prefix) + else: + _add(f"{provider_prefix}/{primary_model}", provider_prefix) + + # Bare base_model (or model_name) with provider hint — handles entries + # the upstream map keys without a provider prefix (most ``gpt-*`` and + # ``claude-*`` entries do this). + if primary_model: + _add(primary_model, provider_prefix) + + # Fallback to model_name when base_model differs (e.g. an Azure + # deployment whose model_name is the deployment id but base_model is the + # canonical OpenAI sku). + if bare_model and bare_model != primary_model: + if provider_prefix and "/" not in bare_model: + _add(f"{provider_prefix}/{bare_model}", provider_prefix) + _add(bare_model, provider_prefix) + _add(bare_model, None) + + return candidates + + +def derive_supports_image_input( + *, + provider: str | None = None, + model_name: str | None = None, + base_model: str | None = None, + custom_provider: str | None = None, + openrouter_input_modalities: Iterable[str] | None = None, +) -> bool: + """Best-effort capability flag for the new-chat selector and catalog. + + Resolution order (first definitive answer wins): + + 1. ``openrouter_input_modalities`` (when provided as a non-empty + iterable). OpenRouter exposes ``architecture.input_modalities`` per + model and that's the authoritative source for OR dynamic configs. + 2. ``litellm.supports_vision`` against each candidate identifier from + :func:`_candidate_model_strings`. Returns True as soon as any + candidate confirms vision support. + 3. Default ``True`` — the conservative-allow stance. An unknown / + newly-added / third-party-hosted model is *not* pre-judged. The + streaming safety net (:func:`is_known_text_only_chat_model`) is the + only place a False ever blocks; everywhere else, a False here would + just hide a usable model from the user. + + Returns: + True if the model can plausibly accept image input, False only when + OpenRouter explicitly says it can't. + """ + if openrouter_input_modalities is not None: + modalities = list(openrouter_input_modalities) + if modalities: + return "image" in modalities + # Empty list explicitly published by OR — treat as "no image". + return False + + for model_string, custom_llm_provider in _candidate_model_strings( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ): + try: + if litellm.supports_vision( + model=model_string, custom_llm_provider=custom_llm_provider + ): + return True + except Exception as exc: + logger.debug( + "litellm.supports_vision raised for model=%s provider=%s: %s", + model_string, + custom_llm_provider, + exc, + ) + continue + + # Default-allow. ``is_known_text_only_chat_model`` is the strict gate. + return True + + +def is_known_text_only_chat_model( + *, + provider: str | None = None, + model_name: str | None = None, + base_model: str | None = None, + custom_provider: str | None = None, +) -> bool: + """Strict opt-out probe for the streaming-task safety net. + + Returns True only when LiteLLM's model map *explicitly* sets + ``supports_vision=False`` for at least one candidate identifier. Missing + key, lookup exception, or ``supports_vision=True`` all return False so + the streaming task lets the request through. This is the inverse-default + of :func:`derive_supports_image_input`. + + Why two functions + ----------------- + The selector wants "show me everything that's plausibly capable" — + default-allow. The safety net wants "block only when I'm certain it + can't" — default-pass. Mixing the two intents in a single function + leads to the regression we're fixing here. + """ + for model_string, custom_llm_provider in _candidate_model_strings( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ): + try: + info = litellm.get_model_info( + model=model_string, custom_llm_provider=custom_llm_provider + ) + except Exception as exc: + logger.debug( + "litellm.get_model_info raised for model=%s provider=%s: %s", + model_string, + custom_llm_provider, + exc, + ) + continue + + # ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision`` + # may be missing, None, True, or False. We only fire on explicit + # False — None / missing / True all mean "don't block". + try: + value = info.get("supports_vision") # type: ignore[union-attr] + except AttributeError: + value = None + if value is False: + return True + + return False + + +__all__ = [ + "derive_supports_image_input", + "is_known_text_only_chat_model", +] diff --git a/surfsense_backend/app/services/quota_checked_vision_llm.py b/surfsense_backend/app/services/quota_checked_vision_llm.py new file mode 100644 index 000000000..0040e5a5b --- /dev/null +++ b/surfsense_backend/app/services/quota_checked_vision_llm.py @@ -0,0 +1,105 @@ +""" +Vision LLM proxy that enforces premium credit quota on every ``ainvoke``. + +Used by :func:`app.services.llm_service.get_vision_llm` so callers in the +indexing pipeline (file processors, connector indexers, etl pipeline) can +keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)`` +— without threading ``user_id`` through every parser. The wrapper looks like +a chat model from the outside; on the inside it routes each call through +``billable_call`` so the user's premium credit pool is reserved → finalized +or released, and a ``TokenUsage`` audit row is written. + +Free configs are returned unwrapped from ``get_vision_llm`` (they do not +need quota enforcement) so this class only ever wraps premium configs. + +Why a wrapper instead of plumbing ``user_id`` through every caller: + +* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive, + Dropbox, local-folder, file-processor, ETL pipeline) each calling + ``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is + invasive, error-prone, and easy for a future indexer to forget. +* Per the design (issue M), we always debit the *search-space owner*, not + the triggering user, so ``user_id`` is fully derivable from the search + space the caller is already operating on. The wrapper captures it once + at construction time. +* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each + call run this coroutine"; subclassing isn't safe across versions because + it derives from ``BaseChatModel`` which expects specific Pydantic shapes. + Composition via attribute proxying (``__getattr__``) is robust to + upstream changes — every method other than ``ainvoke`` falls through to + the inner LLM unchanged. +""" + +from __future__ import annotations + +import logging +from typing import Any +from uuid import UUID + +from app.services.billable_calls import QuotaInsufficientError, billable_call + +logger = logging.getLogger(__name__) + + +class QuotaCheckedVisionLLM: + """Composition wrapper around a langchain chat model that enforces + premium credit quota on every ``ainvoke``. + + Anything other than ``ainvoke`` is forwarded to the inner model so + ``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all + still work — they simply bypass quota enforcement, which is fine + because the indexing pipeline only ever calls ``ainvoke`` today. + """ + + def __init__( + self, + inner_llm: Any, + *, + user_id: UUID, + search_space_id: int, + billing_tier: str, + base_model: str, + quota_reserve_tokens: int | None, + usage_type: str = "vision_extraction", + ) -> None: + self._inner = inner_llm + self._user_id = user_id + self._search_space_id = search_space_id + self._billing_tier = billing_tier + self._base_model = base_model + self._quota_reserve_tokens = quota_reserve_tokens + self._usage_type = usage_type + + async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any: + """Proxied async invoke that runs the underlying call inside + ``billable_call``. + + Raises: + QuotaInsufficientError: when the user has exhausted their + premium credit pool. Caller (``etl_pipeline_service._extract_image``) + catches this and falls back to the document parser. + """ + async with billable_call( + user_id=self._user_id, + search_space_id=self._search_space_id, + billing_tier=self._billing_tier, + base_model=self._base_model, + quota_reserve_tokens=self._quota_reserve_tokens, + usage_type=self._usage_type, + call_details={"model": self._base_model}, + ): + return await self._inner.ainvoke(input, *args, **kwargs) + + def __getattr__(self, name: str) -> Any: + """Forward everything else (``invoke``, ``astream``, ``bind``, + ``with_structured_output``, …) to the inner model. + + ``__getattr__`` is only consulted when the attribute is *not* + already found on the proxy, which is exactly the contract we + want — methods we override stay on the proxy, the rest fall + through. + """ + return getattr(self._inner, name) + + +__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"] diff --git a/surfsense_backend/app/services/token_quota_service.py b/surfsense_backend/app/services/token_quota_service.py index a3ec7aed0..310c3eb5e 100644 --- a/surfsense_backend/app/services/token_quota_service.py +++ b/surfsense_backend/app/services/token_quota_service.py @@ -22,6 +22,71 @@ from app.config import config logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Per-call reservation estimator (USD micro-units) +# --------------------------------------------------------------------------- + +# Minimum reserve in micros so a user with $0.0001 left can still make a tiny +# request, and so models without registered pricing reserve at least +# something while the call runs (debited 0 at finalize anyway when their +# cost can't be resolved). +_QUOTA_MIN_RESERVE_MICROS = 100 + + +def estimate_call_reserve_micros( + *, + base_model: str, + quota_reserve_tokens: int | None, +) -> int: + """Return the number of micro-USD to reserve for one premium call. + + Computes a worst-case upper bound from LiteLLM's per-token pricing + table: + + reserve_usd ≈ reserve_tokens x (input_cost + output_cost) + + so the math scales with model cost — Claude Opus + 4K reserve_tokens + naturally reserves ≈ $0.36, while a cheap model reserves only a few + cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]`` + so a misconfigured "$1000/M" model can't lock the whole balance on + one call. + + If ``litellm.get_model_info`` raises (model unknown) we fall back to + the floor — 100 micros / $0.0001 — which is enough to gate a sane + request without over-reserving for a model whose pricing the + operator hasn't declared yet. + """ + reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL + if reserve_tokens <= 0: + reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL + + try: + from litellm import get_model_info + + info = get_model_info(base_model) if base_model else {} + input_cost = float(info.get("input_cost_per_token") or 0.0) + output_cost = float(info.get("output_cost_per_token") or 0.0) + except Exception as exc: + logger.debug( + "[quota_reserve] cost lookup failed for base_model=%s: %s", + base_model, + exc, + ) + input_cost = 0.0 + output_cost = 0.0 + + if input_cost == 0.0 and output_cost == 0.0: + return _QUOTA_MIN_RESERVE_MICROS + + reserve_usd = reserve_tokens * (input_cost + output_cost) + reserve_micros = round(reserve_usd * 1_000_000) + if reserve_micros < _QUOTA_MIN_RESERVE_MICROS: + reserve_micros = _QUOTA_MIN_RESERVE_MICROS + if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS: + reserve_micros = config.QUOTA_MAX_RESERVE_MICROS + return reserve_micros + + class QuotaScope(StrEnum): ANONYMOUS = "anonymous" PREMIUM = "premium" @@ -444,8 +509,16 @@ class TokenQuotaService: db_session: AsyncSession, user_id: Any, request_id: str, - reserve_tokens: int, + reserve_micros: int, ) -> QuotaResult: + """Reserve ``reserve_micros`` (USD micro-units) from the user's + premium credit balance. + + ``QuotaResult.used``/``limit``/``reserved``/``remaining`` are + all in micro-USD on this code path; callers (chat stream, + token-status route, FE display) convert to dollars by dividing + by 1_000_000. + """ from app.db import User user = ( @@ -465,11 +538,11 @@ class TokenQuotaService: limit=0, ) - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved - effective = used + reserved + reserve_tokens + effective = used + reserved + reserve_micros if effective > limit: remaining = max(0, limit - used - reserved) await db_session.rollback() @@ -482,10 +555,10 @@ class TokenQuotaService: remaining=remaining, ) - user.premium_tokens_reserved = reserved + reserve_tokens + user.premium_credit_micros_reserved = reserved + reserve_micros await db_session.commit() - new_reserved = reserved + reserve_tokens + new_reserved = reserved + reserve_micros remaining = max(0, limit - used - new_reserved) warning_threshold = int(limit * 0.8) @@ -510,9 +583,12 @@ class TokenQuotaService: db_session: AsyncSession, user_id: Any, request_id: str, - actual_tokens: int, - reserved_tokens: int, + actual_micros: int, + reserved_micros: int, ) -> QuotaResult: + """Settle the reservation: release ``reserved_micros`` and debit + ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD). + """ from app.db import User user = ( @@ -529,16 +605,18 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - user.premium_tokens_reserved = max( - 0, user.premium_tokens_reserved - reserved_tokens + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros + ) + user.premium_credit_micros_used = ( + user.premium_credit_micros_used + actual_micros ) - user.premium_tokens_used = user.premium_tokens_used + actual_tokens await db_session.commit() - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved remaining = max(0, limit - used - reserved) warning_threshold = int(limit * 0.8) @@ -562,8 +640,13 @@ class TokenQuotaService: async def premium_release( db_session: AsyncSession, user_id: Any, - reserved_tokens: int, + reserved_micros: int, ) -> None: + """Release ``reserved_micros`` previously held by ``premium_reserve``. + + Used when a request fails before finalize (so the reservation + doesn't leak credit). + """ from app.db import User user = ( @@ -576,8 +659,8 @@ class TokenQuotaService: .scalar_one_or_none() ) if user is not None: - user.premium_tokens_reserved = max( - 0, user.premium_tokens_reserved - reserved_tokens + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros ) await db_session.commit() @@ -598,9 +681,9 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved remaining = max(0, limit - used - reserved) warning_threshold = int(limit * 0.8) diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py index 9aa8c6e70..9406d9be4 100644 --- a/surfsense_backend/app/services/token_tracking_service.py +++ b/surfsense_backend/app/services/token_tracking_service.py @@ -16,11 +16,14 @@ from __future__ import annotations import dataclasses import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from contextvars import ContextVar from dataclasses import dataclass, field from typing import Any from uuid import UUID +import litellm from litellm.integrations.custom_logger import CustomLogger from sqlalchemy.ext.asyncio import AsyncSession @@ -35,6 +38,8 @@ class TokenCallRecord: prompt_tokens: int completion_tokens: int total_tokens: int + cost_micros: int = 0 + call_kind: str = "chat" @dataclass @@ -49,6 +54,8 @@ class TurnTokenAccumulator: prompt_tokens: int, completion_tokens: int, total_tokens: int, + cost_micros: int = 0, + call_kind: str = "chat", ) -> None: self.calls.append( TokenCallRecord( @@ -56,20 +63,28 @@ class TurnTokenAccumulator: prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, + call_kind=call_kind, ) ) def per_message_summary(self) -> dict[str, dict[str, int]]: - """Return token counts grouped by model name.""" + """Return token counts (and cost) grouped by model name.""" by_model: dict[str, dict[str, int]] = {} for c in self.calls: entry = by_model.setdefault( c.model, - {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "cost_micros": 0, + }, ) entry["prompt_tokens"] += c.prompt_tokens entry["completion_tokens"] += c.completion_tokens entry["total_tokens"] += c.total_tokens + entry["cost_micros"] += c.cost_micros return by_model @property @@ -84,6 +99,21 @@ class TurnTokenAccumulator: def total_completion_tokens(self) -> int: return sum(c.completion_tokens for c in self.calls) + @property + def total_cost_micros(self) -> int: + """Sum of per-call ``cost_micros`` across the entire turn. + + Used by ``stream_new_chat`` to debit a premium turn's actual + provider cost (in micro-USD) from the user's premium credit + balance. ``cost_micros`` per call is captured by + ``TokenTrackingCallback.async_log_success_event`` from + ``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost), + with multiple fallback paths so OpenRouter dynamic models and + custom Azure deployments still bill correctly when our + ``pricing_registration`` ran at startup. + """ + return sum(c.cost_micros for c in self.calls) + def serialized_calls(self) -> list[dict[str, Any]]: return [dataclasses.asdict(c) for c in self.calls] @@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar( def start_turn() -> TurnTokenAccumulator: - """Create a fresh accumulator for the current async context and return it.""" + """Create a fresh accumulator for the current async context and return it. + + NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For + short-lived per-call billable wrappers (image generation REST endpoint, + vision LLM during indexing) prefer :func:`scoped_turn`, which uses a + ContextVar reset token to restore the *previous* accumulator on exit and + avoids leaking call records across reservations (issue B). + """ acc = TurnTokenAccumulator() _turn_accumulator.set(acc) logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc)) @@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None: return _turn_accumulator.get() +@asynccontextmanager +async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]: + """Async context manager that scopes a fresh ``TurnTokenAccumulator`` + for the duration of the ``async with`` block, then *resets* the + ContextVar to its previous value on exit. + + This is the safe primitive for per-call billable operations + (image generation, vision LLM extraction, podcasts) that may run + inside an outer chat turn or be called sequentially from the same + background worker. Using ``ContextVar.set`` without ``reset`` (as + :func:`start_turn` does) would leak the inner accumulator into the + outer scope, causing the outer chat turn to debit cost twice. + + Usage:: + + async with scoped_turn() as acc: + await llm.ainvoke(...) + # acc.total_cost_micros captures cost from the LiteLLM callback + # Outer accumulator (if any) is restored here. + """ + acc = TurnTokenAccumulator() + token = _turn_accumulator.set(acc) + logger.debug( + "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)", + id(acc), + token, + ) + try: + yield acc + finally: + _turn_accumulator.reset(token) + logger.debug( + "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)", + id(acc), + len(acc.calls), + acc.total_cost_micros, + ) + + +def _extract_cost_usd( + kwargs: dict[str, Any], + response_obj: Any, + model: str, + prompt_tokens: int, + completion_tokens: int, + is_image: bool = False, +) -> float: + """Best-effort USD cost extraction for a single LLM/image call. + + Tries four sources in priority order and returns the first that + yields a positive number; returns 0.0 if all four fail (the call + will then debit nothing from the user's balance — fail-safe). + + Sources: + 1. ``kwargs["response_cost"]`` — LiteLLM's standard callback + field, populated for ``Router.acompletion`` since PR #12500. + 2. ``response_obj._hidden_params["response_cost"]`` — same value + exposed on the response itself. + 3. ``litellm.completion_cost(completion_response=response_obj)`` + — recompute from the response and LiteLLM's pricing table. + 4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)`` + — manual fallback for OpenRouter/custom-Azure models that + only resolve via aliases registered by + ``pricing_registration`` at startup. **Skipped for image + responses** — ``cost_per_token`` does not support ``ImageResponse`` + and would raise; the cost map for image-gen lives in different + keys (``output_cost_per_image``) handled by ``completion_cost``. + """ + cost = kwargs.get("response_cost") + if cost is not None: + try: + value = float(cost) + except (TypeError, ValueError): + value = 0.0 + if value > 0: + return value + + hidden = getattr(response_obj, "_hidden_params", None) or {} + if isinstance(hidden, dict): + cost = hidden.get("response_cost") + if cost is not None: + try: + value = float(cost) + except (TypeError, ValueError): + value = 0.0 + if value > 0: + return value + + try: + value = float(litellm.completion_cost(completion_response=response_obj)) + if value > 0: + return value + except Exception as exc: + if is_image: + # Image-gen path: OpenRouter's image responses can omit + # ``usage.cost`` and LiteLLM's ``default_image_cost_calculator`` + # then *raises* (no cost map for OpenRouter image models). + # Bail out with a warning rather than falling through to + # cost_per_token (which is also incompatible with ImageResponse). + logger.warning( + "[TokenTracking] completion_cost failed for image model=%s " + "(provider may have omitted usage.cost). Debiting 0. " + "Cause: %s", + model, + exc, + ) + return 0.0 + logger.debug( + "[TokenTracking] completion_cost failed for model=%s: %s", model, exc + ) + + if is_image: + # Never call cost_per_token for ImageResponse — keys mismatch and + # the function is documented chat-only. + return 0.0 + + if model and (prompt_tokens > 0 or completion_tokens > 0): + try: + prompt_cost, completion_cost = litellm.cost_per_token( + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + value = float(prompt_cost) + float(completion_cost) + if value > 0: + return value + except Exception as exc: + logger.debug( + "[TokenTracking] cost_per_token failed for model=%s: %s", model, exc + ) + + return 0.0 + + class TokenTrackingCallback(CustomLogger): """LiteLLM callback that captures token usage into the turn accumulator.""" @@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger): ) return + # Detect image generation responses — they have a different usage + # shape (ImageUsage with input_tokens/output_tokens) and require a + # different cost-extraction path. We probe by class name to avoid a + # hard import dependency on litellm internals. + response_cls = type(response_obj).__name__ + is_image = response_cls == "ImageResponse" + usage = getattr(response_obj, "usage", None) if not usage: logger.debug( @@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger): ) return - prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 - completion_tokens = getattr(usage, "completion_tokens", 0) or 0 - total_tokens = getattr(usage, "total_tokens", 0) or 0 + if is_image: + # ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens`` + # (not prompt_tokens/completion_tokens). Several providers + # populate only one or neither (e.g. OpenRouter's gpt-image-1 + # passes through `input_tokens` from the prompt but no + # completion); fall through gracefully to 0. + prompt_tokens = getattr(usage, "input_tokens", 0) or 0 + completion_tokens = getattr(usage, "output_tokens", 0) or 0 + total_tokens = ( + getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens + ) + call_kind = "image_generation" + else: + prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(usage, "completion_tokens", 0) or 0 + total_tokens = getattr(usage, "total_tokens", 0) or 0 + call_kind = "chat" model = kwargs.get("model", "unknown") + cost_usd = _extract_cost_usd( + kwargs=kwargs, + response_obj=response_obj, + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + is_image=is_image, + ) + cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0 + + if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0): + logger.warning( + "[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d " + "kind=%s — debiting 0. Register pricing via pricing_registration or YAML " + "input_cost_per_token/output_cost_per_token (or rely on response_cost " + "for image generation).", + model, + prompt_tokens, + completion_tokens, + call_kind, + ) + acc.add( model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, + call_kind=call_kind, ) logger.info( - "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)", + "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d " + "cost=$%.6f (%d micros) (accumulator now has %d calls)", model, + call_kind, prompt_tokens, completion_tokens, total_tokens, + cost_usd, + cost_micros, len(acc.calls), ) @@ -168,6 +388,7 @@ async def record_token_usage( prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: int = 0, + cost_micros: int = 0, model_breakdown: dict[str, Any] | None = None, call_details: dict[str, Any] | None = None, thread_id: int | None = None, @@ -185,6 +406,7 @@ async def record_token_usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, model_breakdown=model_breakdown, call_details=call_details, thread_id=thread_id, @@ -194,11 +416,12 @@ async def record_token_usage( ) session.add(record) logger.debug( - "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d", + "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d", usage_type, prompt_tokens, completion_tokens, total_tokens, + cost_micros, ) return record except Exception: diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py index 0d782ab2b..ed5de921c 100644 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -3,6 +3,8 @@ from typing import Any from litellm import Router +from app.services.provider_api_base import resolve_api_base + logger = logging.getLogger(__name__) VISION_AUTO_MODE_ID = 0 @@ -108,10 +110,11 @@ class VisionLLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None + provider = config.get("provider", "").upper() if config.get("custom_provider"): - model_string = f"{config['custom_provider']}/{config['model_name']}" + provider_prefix = config["custom_provider"] + model_string = f"{provider_prefix}/{config['model_name']}" else: - provider = config.get("provider", "").upper() provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower()) model_string = f"{provider_prefix}/{config['model_name']}" @@ -120,8 +123,13 @@ class VisionLLMRouterService: "api_key": config.get("api_key"), } - if config.get("api_base"): - litellm_params["api_base"] = config["api_base"] + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) + if api_base: + litellm_params["api_base"] = api_base if config.get("api_version"): litellm_params["api_version"] = config["api_version"] diff --git a/surfsense_backend/app/tasks/celery_tasks/__init__.py b/surfsense_backend/app/tasks/celery_tasks/__init__.py index 5b1f2cd13..b23359f36 100644 --- a/surfsense_backend/app/tasks/celery_tasks/__init__.py +++ b/surfsense_backend/app/tasks/celery_tasks/__init__.py @@ -1,10 +1,25 @@ -"""Celery tasks package.""" +"""Celery tasks package. + +Also hosts the small helpers every async celery task should use to +spin up its event loop. See :func:`run_async_celery_task` for the +canonical pattern. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +from collections.abc import Awaitable, Callable +from typing import TypeVar from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.pool import NullPool from app.config import config +logger = logging.getLogger(__name__) + _celery_engine = None _celery_session_maker = None @@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker: _celery_engine, expire_on_commit=False ) return _celery_session_maker + + +def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None: + """Drop the shared ``app.db.engine`` connection pool synchronously. + + The shared engine (used by ``shielded_async_session`` and most + routes / services) is a module-level singleton with a real pool. + Each celery task creates a fresh ``asyncio`` event loop; asyncpg + connections cache a reference to whichever loop opened them. When + a subsequent task's loop pulls a stale connection from the pool, + SQLAlchemy's ``pool_pre_ping`` checkout crashes with:: + + AttributeError: 'NoneType' object has no attribute 'send' + File ".../asyncio/proactor_events.py", line 402, in _loop_writing + self._write_fut = self._loop._proactor.send(self._sock, data) + + or hangs forever inside the asyncpg ``Connection._cancel`` cleanup + coroutine that can never run because its loop is gone. + + Disposing the engine forces the pool to drop every cached + connection so the next checkout opens a fresh one on the current + loop. Safe to call from a task's finally block; failure is logged + but never propagated. + """ + try: + from app.db import engine as shared_engine + + loop.run_until_complete(shared_engine.dispose()) + except Exception: + logger.warning("Shared DB engine dispose() failed", exc_info=True) + + +T = TypeVar("T") + + +def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T: + """Run an async coroutine inside a fresh event loop with proper + DB-engine cleanup. + + This is the canonical entry point for every async celery task. + It performs three responsibilities that were previously copy-pasted + (incorrectly) across each task module: + + 1. Create a fresh ``asyncio`` loop and install it on the current + thread (celery's ``--pool=solo`` runs every task on the main + thread, but other pool types don't). + 2. Dispose the shared ``app.db.engine`` BEFORE the task runs so + any stale connections left over from a previous task's loop + are dropped — defends against tasks that crashed without + cleaning up. + 3. Dispose the shared engine AFTER the task runs so the + connections we opened on this loop are released before the + loop closes (avoids ``coroutine 'Connection._cancel' was + never awaited`` warnings and the next-task hang). + + Use as:: + + @celery_app.task(name="my_task", bind=True) + def my_task(self, *args): + return run_async_celery_task(lambda: _my_task_impl(*args)) + """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + # Defense-in-depth: prior task may have crashed before + # disposing. Idempotent — no-op if pool is already empty. + _dispose_shared_db_engine(loop) + return loop.run_until_complete(coro_factory()) + finally: + # Drop any connections this task opened so they don't leak + # into the next task's loop. + _dispose_shared_db_engine(loop) + with contextlib.suppress(Exception): + loop.run_until_complete(loop.shutdown_asyncgens()) + with contextlib.suppress(Exception): + asyncio.set_event_loop(None) + loop.close() + + +__all__ = [ + "get_celery_session_maker", + "run_async_celery_task", +] diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py index fe1ac19d3..08d96cfa0 100644 --- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py @@ -4,7 +4,7 @@ import logging import traceback from app.celery_app import celery_app -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -49,22 +49,15 @@ def index_notion_pages_task( end_date: str, ): """Celery task to index Notion pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_notion_pages( + return run_async_celery_task( + lambda: _index_notion_pages( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_notion_pages", connector_id) raise - finally: - loop.close() async def _index_notion_pages( @@ -95,19 +88,11 @@ def index_github_repos_task( end_date: str, ): """Celery task to index GitHub repositories.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_github_repos( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_github_repos( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_github_repos( @@ -138,19 +123,11 @@ def index_confluence_pages_task( end_date: str, ): """Celery task to index Confluence pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_confluence_pages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_confluence_pages( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_confluence_pages( @@ -181,22 +158,15 @@ def index_google_calendar_events_task( end_date: str, ): """Celery task to index Google Calendar events.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_google_calendar_events( + return run_async_celery_task( + lambda: _index_google_calendar_events( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_google_calendar_events", connector_id) raise - finally: - loop.close() async def _index_google_calendar_events( @@ -227,19 +197,11 @@ def index_google_gmail_messages_task( end_date: str, ): """Celery task to index Google Gmail messages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_google_gmail_messages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_google_gmail_messages( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_google_gmail_messages( @@ -269,22 +231,14 @@ def index_google_drive_files_task( items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options' ): """Celery task to index Google Drive folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_google_drive_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_google_drive_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_google_drive_files( @@ -317,22 +271,14 @@ def index_onedrive_files_task( items_dict: dict, ): """Celery task to index OneDrive folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_onedrive_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_onedrive_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_onedrive_files( @@ -365,22 +311,14 @@ def index_dropbox_files_task( items_dict: dict, ): """Celery task to index Dropbox folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_dropbox_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_dropbox_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_dropbox_files( @@ -414,19 +352,11 @@ def index_elasticsearch_documents_task( end_date: str, ): """Celery task to index Elasticsearch documents.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_elasticsearch_documents( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_elasticsearch_documents( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_elasticsearch_documents( @@ -457,22 +387,15 @@ def index_crawled_urls_task( end_date: str, ): """Celery task to index Web page Urls.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_crawled_urls( + return run_async_celery_task( + lambda: _index_crawled_urls( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_crawled_urls", connector_id) raise - finally: - loop.close() async def _index_crawled_urls( @@ -503,19 +426,11 @@ def index_bookstack_pages_task( end_date: str, ): """Celery task to index BookStack pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_bookstack_pages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_bookstack_pages( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_bookstack_pages( @@ -546,19 +461,11 @@ def index_composio_connector_task( end_date: str | None, ): """Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio).""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_composio_connector( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_composio_connector( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_composio_connector( diff --git a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py index c2dbe7700..5d6bde6c1 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py @@ -11,7 +11,7 @@ from app.db import Document from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str): document_id: ID of document to reindex user_id: ID of user who edited the document """ - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_reindex_document(document_id, user_id)) - finally: - loop.close() + return run_async_celery_task(lambda: _reindex_document(document_id, user_id)) async def _reindex_document(document_id: int, user_id: str): diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index 9d12f91f6..c78e376bd 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -11,7 +11,7 @@ from app.celery_app import celery_app from app.config import config from app.services.notification_service import NotificationService from app.services.task_logging_service import TaskLoggingService -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task from app.tasks.connector_indexers.local_folder_indexer import ( index_local_folder, index_uploaded_files, @@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int): ) def delete_document_task(self, document_id: int): """Celery task to delete a document and its chunks in batches.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_delete_document_background(document_id)) - finally: - loop.close() + return run_async_celery_task(lambda: _delete_document_background(document_id)) async def _delete_document_background(document_id: int) -> None: @@ -153,14 +148,9 @@ def delete_folder_documents_task( folder_subtree_ids: list[int] | None = None, ): """Celery task to delete documents first, then the folder rows.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _delete_folder_documents(document_ids, folder_subtree_ids) - ) - finally: - loop.close() + return run_async_celery_task( + lambda: _delete_folder_documents(document_ids, folder_subtree_ids) + ) async def _delete_folder_documents( @@ -209,12 +199,9 @@ async def _delete_folder_documents( ) def delete_search_space_task(self, search_space_id: int): """Celery task to delete a search space and heavy child rows in batches.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_delete_search_space_background(search_space_id)) - finally: - loop.close() + return run_async_celery_task( + lambda: _delete_search_space_background(search_space_id) + ) async def _delete_search_space_background(search_space_id: int) -> None: @@ -269,18 +256,11 @@ def process_extension_document_task( search_space_id: ID of the search space user_id: ID of the user """ - # Create a new event loop for this task - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _process_extension_document( - individual_document_dict, search_space_id, user_id - ) + return run_async_celery_task( + lambda: _process_extension_document( + individual_document_dict, search_space_id, user_id ) - finally: - loop.close() + ) async def _process_extension_document( @@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st search_space_id: ID of the search space user_id: ID of the user """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id)) - finally: - loop.close() + return run_async_celery_task( + lambda: _process_youtube_video(url, search_space_id, user_id) + ) async def _process_youtube_video(url: str, search_space_id: int, user_id: str): @@ -573,12 +549,9 @@ def process_file_upload_task( except Exception as e: logger.warning(f"[process_file_upload] Could not get file size: {e}") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _process_file_upload(file_path, filename, search_space_id, user_id) + run_async_celery_task( + lambda: _process_file_upload(file_path, filename, search_space_id, user_id) ) logger.info( f"[process_file_upload] Task completed successfully for: {filename}" @@ -589,8 +562,6 @@ def process_file_upload_task( f"Traceback:\n{traceback.format_exc()}" ) raise - finally: - loop.close() async def _process_file_upload( @@ -811,25 +782,17 @@ def process_file_upload_with_document_task( "File may have been removed before syncing could start." ) # Mark document as failed since file is missing - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _mark_document_failed( - document_id, - "File not found. Please re-upload the file.", - ) + run_async_celery_task( + lambda: _mark_document_failed( + document_id, + "File not found. Please re-upload the file.", ) - finally: - loop.close() + ) return - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _process_file_with_document( + run_async_celery_task( + lambda: _process_file_with_document( document_id, temp_path, filename, @@ -849,8 +812,6 @@ def process_file_upload_with_document_task( f"Traceback:\n{traceback.format_exc()}" ) raise - finally: - loop.close() async def _mark_document_failed(document_id: int, reason: str): @@ -1119,22 +1080,16 @@ def process_circleback_meeting_task( search_space_id: ID of the search space connector_id: ID of the Circleback connector (for deletion support) """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _process_circleback_meeting( - meeting_id, - meeting_name, - markdown_content, - metadata, - search_space_id, - connector_id, - ) + return run_async_celery_task( + lambda: _process_circleback_meeting( + meeting_id, + meeting_name, + markdown_content, + metadata, + search_space_id, + connector_id, ) - finally: - loop.close() + ) async def _process_circleback_meeting( @@ -1291,25 +1246,19 @@ def index_local_folder_task( target_file_paths: list[str] | None = None, ): """Celery task to index a local folder. Config is passed directly — no connector row.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_local_folder_async( - search_space_id=search_space_id, - user_id=user_id, - folder_path=folder_path, - folder_name=folder_name, - exclude_patterns=exclude_patterns, - file_extensions=file_extensions, - root_folder_id=root_folder_id, - enable_summary=enable_summary, - target_file_paths=target_file_paths, - ) + return run_async_celery_task( + lambda: _index_local_folder_async( + search_space_id=search_space_id, + user_id=user_id, + folder_path=folder_path, + folder_name=folder_name, + exclude_patterns=exclude_patterns, + file_extensions=file_extensions, + root_folder_id=root_folder_id, + enable_summary=enable_summary, + target_file_paths=target_file_paths, ) - finally: - loop.close() + ) async def _index_local_folder_async( @@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task( processing_mode: str = "basic", ): """Celery task to index files uploaded from the desktop app.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_uploaded_folder_files_async( - search_space_id=search_space_id, - user_id=user_id, - folder_name=folder_name, - root_folder_id=root_folder_id, - enable_summary=enable_summary, - file_mappings=file_mappings, - use_vision_llm=use_vision_llm, - processing_mode=processing_mode, - ) + return run_async_celery_task( + lambda: _index_uploaded_folder_files_async( + search_space_id=search_space_id, + user_id=user_id, + folder_name=folder_name, + root_folder_id=root_folder_id, + enable_summary=enable_summary, + file_mappings=file_mappings, + use_vision_llm=use_vision_llm, + processing_mode=processing_mode, ) - finally: - loop.close() + ) async def _index_uploaded_folder_files_async( @@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str: @celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1) def ai_sort_search_space_task(self, search_space_id: int, user_id: str): """Full AI sort for all documents in a search space.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id)) - finally: - loop.close() + return run_async_celery_task( + lambda: _ai_sort_search_space_async(search_space_id, user_id) + ) async def _ai_sort_search_space_async(search_space_id: int, user_id: str): @@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str): ) def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int): """Incremental AI sort for a single document after indexing.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _ai_sort_document_async(search_space_id, user_id, document_id) - ) - finally: - loop.close() + return run_async_celery_task( + lambda: _ai_sort_document_async(search_space_id, user_id, document_id) + ) async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int): diff --git a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py index 98b107af3..c6c8666f5 100644 --- a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py @@ -2,14 +2,13 @@ from __future__ import annotations -import asyncio import logging from app.celery_app import celery_app from app.db import SearchSourceConnector from app.schemas.obsidian_plugin import NotePayload from app.services.obsidian_plugin_indexer import upsert_note -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -22,18 +21,13 @@ def index_obsidian_attachment_task( user_id: str, ) -> None: """Process one Obsidian non-markdown attachment asynchronously.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_obsidian_attachment( - connector_id=connector_id, - payload_data=payload_data, - user_id=user_id, - ) + return run_async_celery_task( + lambda: _index_obsidian_attachment( + connector_id=connector_id, + payload_data=payload_data, + user_id=user_id, ) - finally: - loop.close() + ) async def _index_obsidian_attachment( diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py index 953011ecf..8b311576e 100644 --- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -3,14 +3,22 @@ import asyncio import logging import sys +from contextlib import asynccontextmanager from sqlalchemy import select from app.agents.podcaster.graph import graph as podcaster_graph from app.agents.podcaster.state import State as PodcasterState from app.celery_app import celery_app +from app.config import config as app_config from app.db import Podcast, PodcastStatus -from app.tasks.celery_tasks import get_celery_session_maker +from app.services.billable_calls import ( + BillingSettlementError, + QuotaInsufficientError, + _resolve_agent_billing_for_search_space, + billable_call, +) +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -28,6 +36,13 @@ if sys.platform.startswith("win"): # ============================================================================= +@asynccontextmanager +async def _celery_billable_session(): + """Session factory used by billable_call inside the Celery worker loop.""" + async with get_celery_session_maker()() as session: + yield session + + @celery_app.task(name="generate_content_podcast", bind=True) def generate_content_podcast_task( self, @@ -40,27 +55,22 @@ def generate_content_podcast_task( Celery task to generate podcast from source content. Updates existing podcast record created by the tool. """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - _generate_content_podcast( + return run_async_celery_task( + lambda: _generate_content_podcast( podcast_id, source_content, search_space_id, user_prompt, ) ) - loop.run_until_complete(loop.shutdown_asyncgens()) - return result except Exception as e: logger.error(f"Error generating content podcast: {e!s}") - loop.run_until_complete(_mark_podcast_failed(podcast_id)) + try: + run_async_celery_task(lambda: _mark_podcast_failed(podcast_id)) + except Exception: + logger.exception("Failed to mark podcast %s as failed", podcast_id) return {"status": "failed", "podcast_id": podcast_id} - finally: - asyncio.set_event_loop(None) - loop.close() async def _mark_podcast_failed(podcast_id: int) -> None: @@ -96,6 +106,31 @@ async def _generate_content_podcast( podcast.status = PodcastStatus.GENERATING await session.commit() + try: + ( + owner_user_id, + billing_tier, + base_model, + ) = await _resolve_agent_billing_for_search_space( + session, + search_space_id, + thread_id=podcast.thread_id, + ) + except ValueError as resolve_err: + logger.error( + "Podcast %s: cannot resolve billing for search_space=%s: %s", + podcast.id, + search_space_id, + resolve_err, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_resolution_failed", + } + graph_config = { "configurable": { "podcast_title": podcast.title, @@ -109,9 +144,52 @@ async def _generate_content_podcast( db_session=session, ) - graph_result = await podcaster_graph.ainvoke( - initial_state, config=graph_config - ) + try: + async with billable_call( + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS, + usage_type="podcast_generation", + call_details={ + "podcast_id": podcast.id, + "title": podcast.title, + "thread_id": podcast.thread_id, + }, + billable_session_factory=_celery_billable_session, + ): + graph_result = await podcaster_graph.ainvoke( + initial_state, config=graph_config + ) + except QuotaInsufficientError as exc: + logger.info( + "Podcast %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", + podcast.id, + exc.used_micros, + exc.limit_micros, + exc.remaining_micros, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "premium_quota_exhausted", + } + except BillingSettlementError: + logger.exception( + "Podcast %s: premium billing settlement failed", + podcast.id, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_settlement_failed", + } podcast_transcript = graph_result.get("podcast_transcript", []) file_path = graph_result.get("final_podcast_file_path", "") @@ -133,7 +211,14 @@ async def _generate_content_podcast( podcast.podcast_transcript = serializable_transcript podcast.file_location = file_path podcast.status = PodcastStatus.READY + logger.info( + "Podcast %s: committing READY transcript_entries=%d file=%s", + podcast.id, + len(serializable_transcript), + file_path, + ) await session.commit() + logger.info("Podcast %s: READY commit complete", podcast.id) logger.info(f"Successfully generated podcast: {podcast.id}") diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py index 373f04b48..e41251407 100644 --- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py @@ -7,7 +7,7 @@ from sqlalchemy.future import select from app.celery_app import celery_app from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task from app.utils.indexing_locks import is_connector_indexing_locked logger = logging.getLogger(__name__) @@ -20,15 +20,7 @@ def check_periodic_schedules_task(): This task runs every minute and triggers indexing for any connector whose next_scheduled_at time has passed. """ - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_check_and_trigger_schedules()) - finally: - loop.close() + return run_async_celery_task(_check_and_trigger_schedules) async def _check_and_trigger_schedules(): diff --git a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py index e05ae9435..d51c85dee 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py @@ -34,7 +34,7 @@ from sqlalchemy.future import select from app.celery_app import celery_app from app.config import config from app.db import Document, DocumentStatus, Notification -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task(): Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task. Also marks associated pending/processing documents as failed. """ - import asyncio - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + async def _both() -> None: + await _cleanup_stale_notifications() + await _cleanup_stale_document_processing_notifications() - try: - loop.run_until_complete(_cleanup_stale_notifications()) - loop.run_until_complete(_cleanup_stale_document_processing_notifications()) - finally: - loop.close() + return run_async_celery_task(_both) async def _cleanup_stale_notifications(): diff --git a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py index 3aee1a360..ace6ef7ca 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import logging from datetime import UTC, datetime, timedelta @@ -18,7 +17,7 @@ from app.db import ( PremiumTokenPurchaseStatus, ) from app.routes import stripe_routes -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None: @celery_app.task(name="reconcile_pending_stripe_page_purchases") def reconcile_pending_stripe_page_purchases_task(): """Recover paid purchases that were left pending due to missed webhook handling.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_reconcile_pending_page_purchases()) - finally: - loop.close() + return run_async_celery_task(_reconcile_pending_page_purchases) async def _reconcile_pending_page_purchases() -> None: @@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None: @celery_app.task(name="reconcile_pending_stripe_token_purchases") def reconcile_pending_stripe_token_purchases_task(): """Recover paid token purchases that were left pending due to missed webhook handling.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_reconcile_pending_token_purchases()) - finally: - loop.close() + return run_async_celery_task(_reconcile_pending_token_purchases) async def _reconcile_pending_token_purchases() -> None: diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py index 7880b385f..08f22140c 100644 --- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py @@ -3,14 +3,22 @@ import asyncio import logging import sys +from contextlib import asynccontextmanager from sqlalchemy import select from app.agents.video_presentation.graph import graph as video_presentation_graph from app.agents.video_presentation.state import State as VideoPresentationState from app.celery_app import celery_app +from app.config import config as app_config from app.db import VideoPresentation, VideoPresentationStatus -from app.tasks.celery_tasks import get_celery_session_maker +from app.services.billable_calls import ( + BillingSettlementError, + QuotaInsufficientError, + _resolve_agent_billing_for_search_space, + billable_call, +) +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -23,6 +31,13 @@ if sys.platform.startswith("win"): ) +@asynccontextmanager +async def _celery_billable_session(): + """Session factory used by billable_call inside the Celery worker loop.""" + async with get_celery_session_maker()() as session: + yield session + + @celery_app.task(name="generate_video_presentation", bind=True) def generate_video_presentation_task( self, @@ -35,27 +50,30 @@ def generate_video_presentation_task( Celery task to generate video presentation from source content. Updates existing video presentation record created by the tool. """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - _generate_video_presentation( + return run_async_celery_task( + lambda: _generate_video_presentation( video_presentation_id, source_content, search_space_id, user_prompt, ) ) - loop.run_until_complete(loop.shutdown_asyncgens()) - return result except Exception as e: logger.error(f"Error generating video presentation: {e!s}") - loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id)) + # Mark FAILED in a fresh loop — the previous loop is closed. + # Swallow secondary failures; the row will simply stay in + # GENERATING and be flushed by the periodic stale cleanup. + try: + run_async_celery_task( + lambda: _mark_video_presentation_failed(video_presentation_id) + ) + except Exception: + logger.exception( + "Failed to mark video presentation %s as failed", + video_presentation_id, + ) return {"status": "failed", "video_presentation_id": video_presentation_id} - finally: - asyncio.set_event_loop(None) - loop.close() async def _mark_video_presentation_failed(video_presentation_id: int) -> None: @@ -97,6 +115,32 @@ async def _generate_video_presentation( video_pres.status = VideoPresentationStatus.GENERATING await session.commit() + try: + ( + owner_user_id, + billing_tier, + base_model, + ) = await _resolve_agent_billing_for_search_space( + session, + search_space_id, + thread_id=video_pres.thread_id, + ) + except ValueError as resolve_err: + logger.error( + "VideoPresentation %s: cannot resolve billing for " + "search_space=%s: %s", + video_pres.id, + search_space_id, + resolve_err, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": "billing_resolution_failed", + } + graph_config = { "configurable": { "video_title": video_pres.title, @@ -110,9 +154,52 @@ async def _generate_video_presentation( db_session=session, ) - graph_result = await video_presentation_graph.ainvoke( - initial_state, config=graph_config - ) + try: + async with billable_call( + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS, + usage_type="video_presentation_generation", + call_details={ + "video_presentation_id": video_pres.id, + "title": video_pres.title, + "thread_id": video_pres.thread_id, + }, + billable_session_factory=_celery_billable_session, + ): + graph_result = await video_presentation_graph.ainvoke( + initial_state, config=graph_config + ) + except QuotaInsufficientError as exc: + logger.info( + "VideoPresentation %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", + video_pres.id, + exc.used_micros, + exc.limit_micros, + exc.remaining_micros, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": "premium_quota_exhausted", + } + except BillingSettlementError: + logger.exception( + "VideoPresentation %s: premium billing settlement failed", + video_pres.id, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": "billing_settlement_failed", + } # Serialize slides (parsed content + audio info merged) slides_raw = graph_result.get("slides", []) @@ -143,7 +230,14 @@ async def _generate_video_presentation( video_pres.slides = serializable_slides video_pres.scene_codes = serializable_scene_codes video_pres.status = VideoPresentationStatus.READY + logger.info( + "VideoPresentation %s: committing READY slides=%d scene_codes=%d", + video_pres.id, + len(serializable_slides), + len(serializable_scene_codes), + ) await session.commit() + logger.info("VideoPresentation %s: READY commit complete", video_pres.id) logger.info(f"Successfully generated video presentation: {video_pres.id}") diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index dbfe9a67b..f7ddd8909 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -31,6 +31,7 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer +from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.errors import BusyError from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection @@ -96,6 +97,47 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int: return min(delay, TURN_CANCELLING_MAX_DELAY_MS) +def _first_interrupt_value(state: Any) -> dict[str, Any] | None: + """Return the first LangGraph interrupt payload across all snapshot tasks.""" + + def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None: + if isinstance(candidate, dict): + value = candidate.get("value", candidate) + return value if isinstance(value, dict) else None + value = getattr(candidate, "value", None) + if isinstance(value, dict): + return value + if isinstance(candidate, (list, tuple)): + for item in candidate: + extracted = _extract_interrupt_value(item) + if extracted is not None: + return extracted + return None + + for task in getattr(state, "tasks", ()) or (): + try: + interrupts = getattr(task, "interrupts", ()) or () + except (AttributeError, IndexError, TypeError): + interrupts = () + if not interrupts: + extracted = _extract_interrupt_value(task) + if extracted is not None: + return extracted + continue + for interrupt_item in interrupts: + extracted = _extract_interrupt_value(interrupt_item) + if extracted is not None: + return extracted + try: + state_interrupts = getattr(state, "interrupts", ()) or () + except (AttributeError, IndexError, TypeError): + state_interrupts = () + extracted = _extract_interrupt_value(state_interrupts) + if extracted is not None: + return extracted + return None + + def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts. @@ -518,6 +560,29 @@ async def _preflight_llm(llm: Any) -> None: ) +async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None: + """Wait for a discarded speculative agent build to release shared state. + + Used by the parallel preflight + agent-build path. The speculative build + closes over the request-scoped ``AsyncSession`` (for the brief connector + discovery / tool-factory window before its CPU work moves into a worker + thread). If preflight reports a 429 we want to fall back to the original + repin → reload → rebuild path, but we MUST NOT touch ``session`` again + until any in-flight session work owned by the speculative build has + fully settled — :class:`sqlalchemy.ext.asyncio.AsyncSession` is not + concurrency-safe and the same hazard cost us a hard ``InvalidRequestError`` + earlier in this PR (see ``connector_service`` parallel-gather revert). + + We simply ``await`` the task and swallow any exception: in this path the + build's outcome is irrelevant — success populates the agent cache (a free + side effect), failure is discarded. The wasted CPU is acceptable since + 429 fallbacks are rare and the original sequential code also paid the + full build cost on the same path. + """ + with contextlib.suppress(BaseException): + await task + + def _classify_stream_exception( exc: Exception, *, @@ -655,6 +720,7 @@ async def _stream_agent_events( fallback_commit_created_by_id: str | None = None, fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, fallback_commit_thread_id: int | None = None, + runtime_context: Any = None, ) -> AsyncGenerator[str, None]: """Shared async generator that streams and formats astream_events from the agent. @@ -760,7 +826,18 @@ async def _stream_agent_events( return event return None - async for event in agent.astream_events(input_data, config=config, version="v2"): + # Per-invocation runtime context (Phase 1.5). When supplied, + # ``KnowledgePriorityMiddleware`` reads ``mentioned_document_ids`` + # from ``runtime.context`` instead of its constructor closure — the + # prerequisite that lets the compiled-agent cache (Phase 1) reuse a + # single graph across turns. Astream_events_kwargs stays empty when + # callers leave ``runtime_context`` as ``None`` to preserve the + # legacy code path bit-for-bit. + astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"} + if runtime_context is not None: + astream_kwargs["context"] = runtime_context + + async for event in agent.astream_events(input_data, **astream_kwargs): event_type = event.get("event", "") if event_type == "on_chat_model_stream": @@ -1506,10 +1583,10 @@ async def _stream_agent_events( if isinstance(tool_output, dict) else "Podcast" ) - if podcast_status == "processing": + if podcast_status in ("pending", "generating", "processing"): completed_items = [ f"Title: {podcast_title}", - "Audio generation started", + "Podcast generation started", "Processing in background...", ] elif podcast_status == "already_generating": @@ -1518,7 +1595,7 @@ async def _stream_agent_events( "Podcast already in progress", "Please wait for it to complete", ] - elif podcast_status == "error": + elif podcast_status in ("failed", "error"): error_msg = ( tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) @@ -1528,6 +1605,11 @@ async def _stream_agent_events( f"Title: {podcast_title}", f"Error: {error_msg[:50]}", ] + elif podcast_status in ("ready", "success"): + completed_items = [ + f"Title: {podcast_title}", + "Podcast ready", + ] else: completed_items = last_active_step_items yield streaming_service.format_thinking_step( @@ -1710,20 +1792,28 @@ async def _stream_agent_events( if isinstance(tool_output, dict) else {"result": tool_output}, ) - if ( - isinstance(tool_output, dict) - and tool_output.get("status") == "success" + if isinstance(tool_output, dict) and tool_output.get("status") in ( + "pending", + "generating", + "processing", + ): + yield streaming_service.format_terminal_info( + f"Podcast queued: {tool_output.get('title', 'Podcast')}", + "success", + ) + elif isinstance(tool_output, dict) and tool_output.get("status") in ( + "ready", + "success", ): yield streaming_service.format_terminal_info( f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}", "success", ) - else: - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) + elif isinstance(tool_output, dict) and tool_output.get("status") in ( + "failed", + "error", + ): + error_msg = tool_output.get("error", "Unknown error") yield streaming_service.format_terminal_info( f"Podcast generation failed: {error_msg}", "error", @@ -2165,10 +2255,10 @@ async def _stream_agent_events( result.agent_called_update_memory = called_update_memory _log_file_contract("turn_outcome", result) - is_interrupted = state.tasks and any(task.interrupts for task in state.tasks) - if is_interrupted: + interrupt_value = _first_interrupt_value(state) + if interrupt_value is not None: result.is_interrupted = True - result.interrupt_value = state.tasks[0].interrupts[0].value + result.interrupt_value = interrupt_value yield streaming_service.format_interrupt_request(result.interrupt_value) @@ -2236,8 +2326,10 @@ async def stream_new_chat( accumulator = start_turn() - # Premium quota tracking state - _premium_reserved = 0 + # Premium credit (USD micro-units) tracking state. Stores the + # amount reserved up front so we can release it on cancellation + # and finalize-debit the actual provider cost reported by LiteLLM. + _premium_reserved_micros = 0 _premium_request_id: str | None = None _emit_stream_error = partial( @@ -2290,6 +2382,11 @@ async def stream_new_chat( ) _t0 = time.perf_counter() + # Image-bearing turns force the Auto-pin resolver to filter the + # candidate pool to vision-capable cfgs (and force-repin a + # text-only existing pin). For explicit selections this flag is + # a no-op — the resolver returns the user's chosen id unchanged. + _requires_image_input = bool(user_image_data_urls) try: llm_config_id = ( await resolve_or_get_pinned_llm_config_id( @@ -2298,13 +2395,29 @@ async def stream_new_chat( search_space_id=search_space_id, user_id=user_id, selected_llm_config_id=llm_config_id, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id except ValueError as pin_error: + # Auto-pin's "no vision-capable cfg" path raises a ValueError + # whose message we map to the friendly image-input SSE error + # so the user sees the same message regardless of whether + # the gate fired in Auto-mode or in the agent_config check + # below. + error_code = ( + "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" + if _requires_image_input and "vision-capable" in str(pin_error) + else "SERVER_ERROR" + ) + error_kind = ( + "user_error" + if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" + else "server_error" + ) yield _emit_stream_error( message=str(pin_error), - error_kind="server_error", - error_code="SERVER_ERROR", + error_kind=error_kind, + error_code=error_code, ) yield streaming_service.format_done() return @@ -2324,6 +2437,50 @@ async def stream_new_chat( llm_config_id, ) + # Capability safety net: a turn carrying user-uploaded images + # cannot be routed to a chat config that LiteLLM's authoritative + # model map *explicitly* marks as text-only (``supports_vision`` + # set to False). The check is intentionally narrow — it only + # fires when LiteLLM is *certain* the model can't accept image + # input. Unknown / unmapped / vision-capable models pass + # through. Without this guard a known-text-only model would 404 + # at the provider with ``"No endpoints found that support image + # input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk; + # failing here lets us return a friendly message that tells the + # user what to change. + if user_image_data_urls and agent_config is not None: + from app.services.provider_capabilities import ( + is_known_text_only_chat_model, + ) + + agent_litellm_params = agent_config.litellm_params or {} + agent_base_model = ( + agent_litellm_params.get("base_model") + if isinstance(agent_litellm_params, dict) + else None + ) + if is_known_text_only_chat_model( + provider=agent_config.provider, + model_name=agent_config.model_name, + base_model=agent_base_model, + custom_provider=agent_config.custom_provider, + ): + model_label = ( + agent_config.config_name or agent_config.model_name or "model" + ) + yield _emit_stream_error( + message=( + f"The selected model ({model_label}) does not support " + "image input. Switch to a vision-capable model " + "(e.g. GPT-4o, Claude, Gemini) or remove the image " + "attachment and try again." + ), + error_kind="user_error", + error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT", + ) + yield streaming_service.format_done() + return + # Premium quota reservation for pinned premium model only. _needs_premium_quota = ( agent_config is not None and user_id and agent_config.is_premium @@ -2331,23 +2488,28 @@ async def stream_new_chat( if _needs_premium_quota: import uuid as _uuid - from app.config import config as _app_config - from app.services.token_quota_service import TokenQuotaService + from app.services.token_quota_service import ( + TokenQuotaService, + estimate_call_reserve_micros, + ) _premium_request_id = _uuid.uuid4().hex[:16] - reserve_amount = min( - agent_config.quota_reserve_tokens - or _app_config.QUOTA_MAX_RESERVE_PER_CALL, - _app_config.QUOTA_MAX_RESERVE_PER_CALL, + _agent_litellm_params = agent_config.litellm_params or {} + _agent_base_model = ( + _agent_litellm_params.get("base_model") or agent_config.model_name or "" + ) + reserve_amount_micros = estimate_call_reserve_micros( + base_model=_agent_base_model, + quota_reserve_tokens=agent_config.quota_reserve_tokens, ) async with shielded_async_session() as quota_session: quota_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, - reserve_tokens=reserve_amount, + reserve_micros=reserve_amount_micros, ) - _premium_reserved = reserve_amount + _premium_reserved_micros = reserve_amount_micros if not quota_result.allowed: if requested_llm_config_id == 0: try: @@ -2359,6 +2521,7 @@ async def stream_new_chat( user_id=user_id, selected_llm_config_id=0, force_repin_free=True, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id except ValueError as pin_error: @@ -2382,7 +2545,7 @@ async def stream_new_chat( yield streaming_service.format_done() return _premium_request_id = None - _premium_reserved = 0 + _premium_reserved_micros = 0 _log_chat_stream_error( flow=flow, error_kind="premium_quota_exhausted", @@ -2433,23 +2596,102 @@ async def stream_new_chat( # Detecting a 429 here lets us repin BEFORE the planner/classifier/ # title-generation LLM calls fan out and each independently hit the # same upstream rate limit. - if ( + # + # PERF: preflight is a network round-trip to the LLM provider (~1-5s) + # and is independent of the agent build (CPU-bound, ~5-7s). They used + # to run sequentially → ``preflight + build`` on cold cache = 11.5s. + # We now kick off preflight as a background task FIRST, then run the + # synchronous setup work and the agent build in parallel. In the + # success path (the common case) total wall time drops to roughly + # ``max(preflight, build)`` — the preflight finishes during the + # agent compile and we just consume its result. In the rare 429 + # path the speculative build is awaited to completion (so its + # session usage is fully released) via + # :func:`_settle_speculative_agent_build`, then discarded, and + # we fall back to the original repin-and-rebuild flow. + preflight_needed = ( requested_llm_config_id == 0 and llm_config_id < 0 and not is_recently_healthy(llm_config_id) - ): + ) + preflight_task: asyncio.Task[None] | None = None + _t_preflight = 0.0 + if preflight_needed: _t_preflight = time.perf_counter() + preflight_task = asyncio.create_task( + _preflight_llm(llm), + name=f"auto_pin_preflight:{llm_config_id}", + ) + + # Create connector service + _t0 = time.perf_counter() + connector_service = ConnectorService(session, search_space_id=search_space_id) + + firecrawl_api_key = None + webcrawler_connector = await connector_service.get_connector_by_type( + SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id + ) + if webcrawler_connector and webcrawler_connector.config: + firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") + _perf_log.info( + "[stream_new_chat] Connector service + firecrawl key in %.3fs", + time.perf_counter() - _t0, + ) + + # Get the PostgreSQL checkpointer for persistent conversation memory + _t0 = time.perf_counter() + checkpointer = await get_checkpointer() + _perf_log.info( + "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0 + ) + + visibility = thread_visibility or ChatVisibility.PRIVATE + _t0 = time.perf_counter() + # Speculative agent build — runs in parallel with the preflight + # task (if any). Built with the *current* ``llm`` / ``agent_config``; + # if preflight reports 429 we will discard this future and rebuild + # against the freshly pinned config below. + agent_build_task = asyncio.create_task( + create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + filesystem_selection=filesystem_selection, + ), + name="agent_build:stream_new_chat", + ) + + agent: Any = None + if preflight_task is not None: try: - await _preflight_llm(llm) + await preflight_task mark_healthy(llm_config_id) _perf_log.info( - "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs", + "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)", llm_config_id, time.perf_counter() - _t_preflight, ) except Exception as preflight_exc: + # Both branches below need the session: the non-429 path + # may unwind via cleanup that uses ``session``, and the + # 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id`` + # against it. Wait for the speculative build to release its + # session usage before we proceed. + await _settle_speculative_agent_build(agent_build_task) if not _is_provider_rate_limited(preflight_exc): raise + # 429: speculative agent is discarded; run the original + # repin → reload → rebuild path against the freshly + # pinned config. previous_config_id = llm_config_id mark_runtime_cooldown( previous_config_id, reason="preflight_rate_limited" @@ -2463,6 +2705,7 @@ async def stream_new_chat( user_id=user_id, selected_llm_config_id=0, exclude_config_ids={previous_config_id}, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id except ValueError as pin_error: @@ -2511,46 +2754,28 @@ async def stream_new_chat( "fallback_config_id": llm_config_id, }, ) + # Rebuild against the new llm/agent_config. Sequential + # here because we no longer have anything to overlap with. + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + filesystem_selection=filesystem_selection, + ) - # Create connector service - _t0 = time.perf_counter() - connector_service = ConnectorService(session, search_space_id=search_space_id) - - firecrawl_api_key = None - webcrawler_connector = await connector_service.get_connector_by_type( - SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id - ) - if webcrawler_connector and webcrawler_connector.config: - firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") - _perf_log.info( - "[stream_new_chat] Connector service + firecrawl key in %.3fs", - time.perf_counter() - _t0, - ) - - # Get the PostgreSQL checkpointer for persistent conversation memory - _t0 = time.perf_counter() - checkpointer = await get_checkpointer() - _perf_log.info( - "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0 - ) - - visibility = thread_visibility or ChatVisibility.PRIVATE - _t0 = time.perf_counter() - agent = await create_surfsense_deep_agent( - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - disabled_tools=disabled_tools, - mentioned_document_ids=mentioned_document_ids, - filesystem_selection=filesystem_selection, - ) + if agent is None: + # Either no preflight was needed, or preflight succeeded — + # in both cases the speculative build is the agent we want. + agent = await agent_build_task _perf_log.info( "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 ) @@ -2797,6 +3022,7 @@ async def stream_new_chat( from litellm import acompletion from app.services.llm_router_service import LLMRouterService + from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import _turn_accumulator _turn_accumulator.set(None) @@ -2817,11 +3043,32 @@ async def stream_new_chat( model="auto", messages=messages ) else: + # Apply the same ``api_base`` cascade chat / vision / + # image-gen call sites use so we never inherit + # ``litellm.api_base`` (commonly set by + # ``AZURE_OPENAI_ENDPOINT``) when the chat config + # itself ships an empty ``api_base``. Without this + # the title-gen on an OpenRouter chat config would + # 404 against the inherited Azure endpoint — see + # ``provider_api_base`` docstring for the same + # bug repro on the image-gen / vision paths. + raw_model = getattr(llm, "model", "") or "" + provider_prefix = ( + raw_model.split("/", 1)[0] if "/" in raw_model else None + ) + provider_value = ( + agent_config.provider if agent_config is not None else None + ) + title_api_base = resolve_api_base( + provider=provider_value, + provider_prefix=provider_prefix, + config_api_base=getattr(llm, "api_base", None), + ) response = await acompletion( - model=llm.model, + model=raw_model, messages=messages, api_key=getattr(llm, "api_key", None), - api_base=getattr(llm, "api_base", None), + api_base=title_api_base, ) usage_info = None @@ -2855,6 +3102,18 @@ async def stream_new_chat( title_emitted = False + # Build the per-invocation runtime context (Phase 1.5). + # ``mentioned_document_ids`` is read by ``KnowledgePriorityMiddleware`` + # via ``runtime.context.mentioned_document_ids`` instead of its + # ``__init__`` closure — that way the same compiled-agent instance + # can serve multiple turns with different mention lists. + runtime_context = SurfSenseContextSchema( + search_space_id=search_space_id, + mentioned_document_ids=list(mentioned_document_ids or []), + request_id=request_id, + turn_id=stream_result.turn_id, + ) + _t_stream_start = time.perf_counter() _first_event_logged = False runtime_rate_limit_recovered = False @@ -2878,6 +3137,7 @@ async def stream_new_chat( else FilesystemMode.CLOUD ), fallback_commit_thread_id=chat_id, + runtime_context=runtime_context, ): if not _first_event_logged: _perf_log.info( @@ -2946,6 +3206,7 @@ async def stream_new_chat( user_id=user_id, selected_llm_config_id=0, exclude_config_ids={previous_config_id}, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id @@ -3020,9 +3281,10 @@ async def stream_new_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s", + "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3033,6 +3295,7 @@ async def stream_new_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3060,7 +3323,11 @@ async def stream_new_chat( chat_id, generated_title ) - # Finalize premium quota with actual tokens. + # Finalize premium credit debit with the actual provider cost + # reported by LiteLLM, summed across every call in the turn. + # Mirrors the pre-cost behaviour of "premium turn → all calls + # count" so free sub-agent calls during a premium turn still + # contribute to the bill (they're $0 in practice anyway). if _premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService @@ -3070,11 +3337,11 @@ async def stream_new_chat( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, - actual_tokens=accumulator.grand_total, - reserved_tokens=_premium_reserved, + actual_micros=accumulator.total_cost_micros, + reserved_micros=_premium_reserved_micros, ) _premium_request_id = None - _premium_reserved = 0 + _premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s", @@ -3084,9 +3351,10 @@ async def stream_new_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] normal new_chat: calls=%d total=%d summary=%s", + "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3097,6 +3365,7 @@ async def stream_new_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3190,7 +3459,7 @@ async def stream_new_chat( end_turn(str(chat_id)) # Release premium reservation if not finalized - if _premium_request_id and _premium_reserved > 0 and user_id: + if _premium_request_id and _premium_reserved_micros > 0 and user_id: try: from app.services.token_quota_service import TokenQuotaService @@ -3198,9 +3467,9 @@ async def stream_new_chat( await TokenQuotaService.premium_release( db_session=quota_session, user_id=UUID(user_id), - reserved_tokens=_premium_reserved, + reserved_micros=_premium_reserved_micros, ) - _premium_reserved = 0 + _premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to release premium quota for user %s", user_id @@ -3369,8 +3638,8 @@ async def stream_resume_chat( "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 ) - # Premium quota reservation (same logic as stream_new_chat) - _resume_premium_reserved = 0 + # Premium credit reservation (same logic as stream_new_chat). + _resume_premium_reserved_micros = 0 _resume_premium_request_id: str | None = None _resume_needs_premium = ( agent_config is not None and user_id and agent_config.is_premium @@ -3378,23 +3647,30 @@ async def stream_resume_chat( if _resume_needs_premium: import uuid as _uuid - from app.config import config as _app_config - from app.services.token_quota_service import TokenQuotaService + from app.services.token_quota_service import ( + TokenQuotaService, + estimate_call_reserve_micros, + ) _resume_premium_request_id = _uuid.uuid4().hex[:16] - reserve_amount = min( - agent_config.quota_reserve_tokens - or _app_config.QUOTA_MAX_RESERVE_PER_CALL, - _app_config.QUOTA_MAX_RESERVE_PER_CALL, + _resume_litellm_params = agent_config.litellm_params or {} + _resume_base_model = ( + _resume_litellm_params.get("base_model") + or agent_config.model_name + or "" + ) + reserve_amount_micros = estimate_call_reserve_micros( + base_model=_resume_base_model, + quota_reserve_tokens=agent_config.quota_reserve_tokens, ) async with shielded_async_session() as quota_session: quota_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=UUID(user_id), request_id=_resume_premium_request_id, - reserve_tokens=reserve_amount, + reserve_micros=reserve_amount_micros, ) - _resume_premium_reserved = reserve_amount + _resume_premium_reserved_micros = reserve_amount_micros if not quota_result.allowed: if requested_llm_config_id == 0: try: @@ -3429,7 +3705,7 @@ async def stream_resume_chat( yield streaming_service.format_done() return _resume_premium_request_id = None - _resume_premium_reserved = 0 + _resume_premium_reserved_micros = 0 _log_chat_stream_error( flow="resume", error_kind="premium_quota_exhausted", @@ -3477,21 +3753,75 @@ async def stream_resume_chat( # Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``: # one cheap probe before the agent is rebuilt so a 429'd pin gets # repinned without burning planner/classifier/title calls first. - if ( + # See ``stream_new_chat`` for the full rationale on the speculative + # parallel build pattern below. + preflight_needed = ( requested_llm_config_id == 0 and llm_config_id < 0 and not is_recently_healthy(llm_config_id) - ): + ) + preflight_task: asyncio.Task[None] | None = None + _t_preflight = 0.0 + if preflight_needed: _t_preflight = time.perf_counter() + preflight_task = asyncio.create_task( + _preflight_llm(llm), + name=f"auto_pin_preflight_resume:{llm_config_id}", + ) + + _t0 = time.perf_counter() + connector_service = ConnectorService(session, search_space_id=search_space_id) + + firecrawl_api_key = None + webcrawler_connector = await connector_service.get_connector_by_type( + SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id + ) + if webcrawler_connector and webcrawler_connector.config: + firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") + _perf_log.info( + "[stream_resume] Connector service + firecrawl key in %.3fs", + time.perf_counter() - _t0, + ) + + _t0 = time.perf_counter() + checkpointer = await get_checkpointer() + _perf_log.info( + "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0 + ) + + visibility = thread_visibility or ChatVisibility.PRIVATE + + _t0 = time.perf_counter() + agent_build_task = asyncio.create_task( + create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + ), + name="agent_build:stream_resume", + ) + + agent: Any = None + if preflight_task is not None: try: - await _preflight_llm(llm) + await preflight_task mark_healthy(llm_config_id) _perf_log.info( - "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs", + "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)", llm_config_id, time.perf_counter() - _t_preflight, ) except Exception as preflight_exc: + # Same session-safety rationale as ``stream_new_chat``. + await _settle_speculative_agent_build(agent_build_task) if not _is_provider_rate_limited(preflight_exc): raise previous_config_id = llm_config_id @@ -3551,43 +3881,22 @@ async def stream_resume_chat( "fallback_config_id": llm_config_id, }, ) + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + ) - _t0 = time.perf_counter() - connector_service = ConnectorService(session, search_space_id=search_space_id) - - firecrawl_api_key = None - webcrawler_connector = await connector_service.get_connector_by_type( - SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id - ) - if webcrawler_connector and webcrawler_connector.config: - firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") - _perf_log.info( - "[stream_resume] Connector service + firecrawl key in %.3fs", - time.perf_counter() - _t0, - ) - - _t0 = time.perf_counter() - checkpointer = await get_checkpointer() - _perf_log.info( - "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0 - ) - - visibility = thread_visibility or ChatVisibility.PRIVATE - - _t0 = time.perf_counter() - agent = await create_surfsense_deep_agent( - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - ) + if agent is None: + agent = await agent_build_task _perf_log.info( "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 ) @@ -3628,6 +3937,16 @@ async def stream_resume_chat( ) yield streaming_service.format_data("turn-status", {"status": "busy"}) + # Resume path doesn't carry new ``mentioned_document_ids`` — + # those are seeded in the original turn. We still pass a + # context so future middleware extensions (Phase 2) can rely on + # ``runtime.context`` always being populated. + runtime_context = SurfSenseContextSchema( + search_space_id=search_space_id, + request_id=request_id, + turn_id=stream_result.turn_id, + ) + _t_stream_start = time.perf_counter() _first_event_logged = False runtime_rate_limit_recovered = False @@ -3648,6 +3967,7 @@ async def stream_resume_chat( else FilesystemMode.CLOUD ), fallback_commit_thread_id=chat_id, + runtime_context=runtime_context, ): if not _first_event_logged: _perf_log.info( @@ -3746,9 +4066,10 @@ async def stream_resume_chat( if stream_result.is_interrupted: usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s", + "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3759,6 +4080,7 @@ async def stream_resume_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3768,7 +4090,9 @@ async def stream_resume_chat( yield streaming_service.format_done() return - # Finalize premium quota for resume path + # Finalize premium credit debit for resume path with the actual + # provider cost reported by LiteLLM (sum of cost across all + # calls in the turn). if _resume_premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService @@ -3778,11 +4102,11 @@ async def stream_resume_chat( db_session=quota_session, user_id=UUID(user_id), request_id=_resume_premium_request_id, - actual_tokens=accumulator.grand_total, - reserved_tokens=_resume_premium_reserved, + actual_micros=accumulator.total_cost_micros, + reserved_micros=_resume_premium_reserved_micros, ) _resume_premium_request_id = None - _resume_premium_reserved = 0 + _resume_premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s (resume)", @@ -3792,9 +4116,10 @@ async def stream_resume_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] normal resume_chat: calls=%d total=%d summary=%s", + "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3805,6 +4130,7 @@ async def stream_resume_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3855,7 +4181,11 @@ async def stream_resume_chat( end_turn(str(chat_id)) # Release premium reservation if not finalized - if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id: + if ( + _resume_premium_request_id + and _resume_premium_reserved_micros > 0 + and user_id + ): try: from app.services.token_quota_service import TokenQuotaService @@ -3863,9 +4193,9 @@ async def stream_resume_chat( await TokenQuotaService.premium_release( db_session=quota_session, user_id=UUID(user_id), - reserved_tokens=_resume_premium_reserved, + reserved_micros=_resume_premium_reserved_micros, ) - _resume_premium_reserved = 0 + _resume_premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to release premium quota for user %s (resume)", user_id diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py index 6912ffe5a..3c9f27303 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py @@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import ( IndexingPipelineService, PlaceholderInfo, ) +from app.services.composio_service import ComposioService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.google_credentials import ( - COMPOSIO_GOOGLE_CONNECTOR_TYPES, - build_composio_credentials, -) +from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES from .base import ( check_duplicate_document_by_hash, @@ -44,6 +42,10 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]] HEARTBEAT_INTERVAL_SECONDS = 30 +def _format_calendar_event_to_markdown(event: dict) -> str: + return GoogleCalendarConnector.format_event_to_markdown(None, event) + + def _build_connector_doc( event: dict, event_markdown: str, @@ -150,7 +152,14 @@ async def index_google_calendar_events( ) return 0, 0, f"Connector with ID {connector_id} not found" - # ── Credential building ─────────────────────────────────────── + is_composio_connector = ( + connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES + ) + calendar_client = None + composio_service = None + connected_account_id = None + + # ── Credential/client building ──────────────────────────────── if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: connected_account_id = connector.config.get("composio_connected_account_id") if not connected_account_id: @@ -161,7 +170,7 @@ async def index_google_calendar_events( {"error_type": "MissingComposioAccount"}, ) return 0, 0, "Composio connected_account_id not found" - credentials = build_composio_credentials(connected_account_id) + composio_service = ComposioService() else: config_data = connector.config @@ -229,12 +238,13 @@ async def index_google_calendar_events( {"stage": "client_initialization"}, ) - calendar_client = GoogleCalendarConnector( - credentials=credentials, - session=session, - user_id=user_id, - connector_id=connector_id, - ) + if not is_composio_connector: + calendar_client = GoogleCalendarConnector( + credentials=credentials, + session=session, + user_id=user_id, + connector_id=connector_id, + ) # Handle 'undefined' string from frontend (treat as None) if start_date == "undefined" or start_date == "": @@ -300,9 +310,26 @@ async def index_google_calendar_events( ) try: - events, error = await calendar_client.get_all_primary_calendar_events( - start_date=start_date_str, end_date=end_date_str - ) + if is_composio_connector: + start_dt = parse_date_flexible(start_date_str).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + end_dt = parse_date_flexible(end_date_str).replace( + hour=23, minute=59, second=59, microsecond=0 + ) + events, error = await composio_service.get_calendar_events( + connected_account_id=connected_account_id, + entity_id=f"surfsense_{user_id}", + time_min=start_dt.isoformat(), + time_max=end_dt.isoformat(), + max_results=250, + ) + if not events and not error: + error = "No events found in the specified date range." + else: + events, error = await calendar_client.get_all_primary_calendar_events( + start_date=start_date_str, end_date=end_date_str + ) if error: if "No events found" in error: @@ -381,7 +408,7 @@ async def index_google_calendar_events( documents_skipped += 1 continue - event_markdown = calendar_client.format_event_to_markdown(event) + event_markdown = _format_calendar_event_to_markdown(event) if not event_markdown.strip(): logger.warning(f"Skipping event with no content: {event_summary}") documents_skipped += 1 diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index 21cdbd29f..686f13d9e 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -9,6 +9,8 @@ import asyncio import logging import time from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import Any from sqlalchemy import String, cast, select from sqlalchemy.exc import SQLAlchemyError @@ -37,6 +39,7 @@ from app.indexing_pipeline.indexing_pipeline_service import ( IndexingPipelineService, PlaceholderInfo, ) +from app.services.composio_service import ComposioService from app.services.llm_service import get_user_long_context_llm from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService @@ -45,10 +48,7 @@ from app.tasks.connector_indexers.base import ( get_connector_by_id, update_connector_last_indexed, ) -from app.utils.google_credentials import ( - COMPOSIO_GOOGLE_CONNECTOR_TYPES, - build_composio_credentials, -) +from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES ACCEPTED_DRIVE_CONNECTOR_TYPES = { SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, @@ -61,6 +61,209 @@ HEARTBEAT_INTERVAL_SECONDS = 30 logger = logging.getLogger(__name__) +class ComposioDriveClient: + """Google Drive client facade backed by Composio tool execution. + + Composio-managed OAuth connections can execute tools without exposing raw + OAuth tokens through connected account state. + """ + + def __init__( + self, + session: AsyncSession, + connector_id: int, + connected_account_id: str, + entity_id: str, + ): + self.session = session + self.connector_id = connector_id + self.connected_account_id = connected_account_id + self.entity_id = entity_id + self.composio = ComposioService() + + async def list_files( + self, + query: str = "", + fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)", + page_size: int = 100, + page_token: str | None = None, + ) -> tuple[list[dict[str, Any]], str | None, str | None]: + params: dict[str, Any] = { + "page_size": min(page_size, 100), + "fields": fields, + } + if query: + params["q"] = query + if page_token: + params["page_token"] = page_token + + result = await self.composio.execute_tool( + connected_account_id=self.connected_account_id, + tool_name="GOOGLEDRIVE_LIST_FILES", + params=params, + entity_id=self.entity_id, + ) + if not result.get("success"): + return [], None, result.get("error", "Unknown error") + + data = result.get("data", {}) + files = [] + next_token = None + if isinstance(data, dict): + inner_data = data.get("data", data) + if isinstance(inner_data, dict): + files = inner_data.get("files", []) + next_token = inner_data.get("nextPageToken") or inner_data.get( + "next_page_token" + ) + elif isinstance(data, list): + files = data + + return files, next_token, None + + async def get_file_metadata( + self, file_id: str, fields: str = "*" + ) -> tuple[dict[str, Any] | None, str | None]: + result = await self.composio.execute_tool( + connected_account_id=self.connected_account_id, + tool_name="GOOGLEDRIVE_GET_FILE_METADATA", + params={"file_id": file_id, "fields": fields}, + entity_id=self.entity_id, + ) + if not result.get("success"): + return None, result.get("error", "Unknown error") + + data = result.get("data", {}) + if isinstance(data, dict): + inner_data = data.get("data", data) + if isinstance(inner_data, dict): + return inner_data, None + + return None, "Could not extract metadata from Composio response" + + async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]: + return await self._download_file_content(file_id) + + async def download_file_to_disk( + self, + file_id: str, + dest_path: str, + chunksize: int = 5 * 1024 * 1024, + ) -> str | None: + del chunksize + content, error = await self.download_file(file_id) + if error: + return error + if content is None: + return "No content returned from Composio" + Path(dest_path).write_bytes(content) + return None + + async def export_google_file( + self, file_id: str, mime_type: str + ) -> tuple[bytes | None, str | None]: + return await self._download_file_content(file_id, mime_type=mime_type) + + async def _download_file_content( + self, file_id: str, mime_type: str | None = None + ) -> tuple[bytes | None, str | None]: + params: dict[str, Any] = {"file_id": file_id} + if mime_type: + params["mime_type"] = mime_type + + result = await self.composio.execute_tool( + connected_account_id=self.connected_account_id, + tool_name="GOOGLEDRIVE_DOWNLOAD_FILE", + params=params, + entity_id=self.entity_id, + ) + if not result.get("success"): + return None, result.get("error", "Unknown error") + + return self._read_download_result(result.get("data")) + + def _read_download_result(self, data: Any) -> tuple[bytes | None, str | None]: + if isinstance(data, bytes): + return data, None + + file_path: str | None = None + if isinstance(data, str): + file_path = data + elif isinstance(data, dict): + inner_data = data.get("data", data) + if isinstance(inner_data, dict): + for key in ("file_path", "downloaded_file_content", "path", "uri"): + value = inner_data.get(key) + if isinstance(value, str): + file_path = value + break + if isinstance(value, dict): + nested = ( + value.get("file_path") + or value.get("downloaded_file_content") + or value.get("path") + or value.get("uri") + or value.get("s3url") + ) + if isinstance(nested, str): + file_path = nested + break + + if not file_path: + return None, "No file path/content returned from Composio" + + if file_path.startswith(("http://", "https://")): + try: + import urllib.request + + with urllib.request.urlopen(file_path, timeout=60) as response: + return response.read(), None + except Exception as e: + return None, f"Failed to download Composio file URL: {e!s}" + + path_obj = Path(file_path) + if path_obj.is_absolute() or ".composio" in str(path_obj): + if not path_obj.exists(): + return None, f"File not found at path: {file_path}" + return path_obj.read_bytes(), None + + try: + import base64 + + return base64.b64decode(file_path), None + except Exception: + return file_path.encode("utf-8"), None + + +def _build_drive_client_for_connector( + session: AsyncSession, + connector_id: int, + connector: object, + user_id: str, +) -> tuple[GoogleDriveClient | ComposioDriveClient | None, str | None]: + if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: + connected_account_id = connector.config.get("composio_connected_account_id") + if not connected_account_id: + return None, ( + f"Composio connected_account_id not found for connector {connector_id}" + ) + return ( + ComposioDriveClient( + session, + connector_id, + connected_account_id, + entity_id=f"surfsense_{user_id}", + ), + None, + ) + + token_encrypted = connector.config.get("_token_encrypted", False) + if token_encrypted and not config.SECRET_KEY: + return None, "SECRET_KEY not configured but credentials are marked as encrypted" + + return GoogleDriveClient(session, connector_id), None + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -927,34 +1130,17 @@ async def index_google_drive_files( {"stage": "client_initialization"}, ) - pre_built_credentials = None - if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - connected_account_id = connector.config.get("composio_connected_account_id") - if not connected_account_id: - error_msg = f"Composio connected_account_id not found for connector {connector_id}" - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing Composio account", - {"error_type": "MissingComposioAccount"}, - ) - return 0, 0, error_msg, 0 - pre_built_credentials = build_composio_credentials(connected_account_id) - else: - token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted and not config.SECRET_KEY: - await task_logger.log_task_failure( - log_entry, - "SECRET_KEY not configured but credentials are encrypted", - "Missing SECRET_KEY", - {"error_type": "MissingSecretKey"}, - ) - return ( - 0, - 0, - "SECRET_KEY not configured but credentials are marked as encrypted", - 0, - ) + drive_client, client_error = _build_drive_client_for_connector( + session, connector_id, connector, user_id + ) + if client_error or not drive_client: + await task_logger.log_task_failure( + log_entry, + client_error or "Failed to initialize Google Drive client", + "Missing connector credentials", + {"error_type": "ClientInitializationError"}, + ) + return 0, 0, client_error, 0 connector_enable_summary = getattr(connector, "enable_summary", True) connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False) @@ -963,10 +1149,6 @@ async def index_google_drive_files( from app.services.llm_service import get_vision_llm vision_llm = await get_vision_llm(session, search_space_id) - drive_client = GoogleDriveClient( - session, connector_id, credentials=pre_built_credentials - ) - if not folder_id: error_msg = "folder_id is required for Google Drive indexing" await task_logger.log_task_failure( @@ -979,8 +1161,14 @@ async def index_google_drive_files( folder_tokens = connector.config.get("folder_tokens", {}) start_page_token = folder_tokens.get(target_folder_id) + is_composio_connector = ( + connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES + ) can_use_delta = ( - use_delta_sync and start_page_token and connector.last_indexed_at + not is_composio_connector + and use_delta_sync + and start_page_token + and connector.last_indexed_at ) documents_unsupported = 0 @@ -1051,7 +1239,16 @@ async def index_google_drive_files( ) if documents_indexed > 0 or can_use_delta: - new_token, token_error = await get_start_page_token(drive_client) + if isinstance(drive_client, ComposioDriveClient): + ( + new_token, + token_error, + ) = await drive_client.composio.get_drive_start_page_token( + drive_client.connected_account_id, + drive_client.entity_id, + ) + else: + new_token, token_error = await get_start_page_token(drive_client) if new_token and not token_error: await session.refresh(connector) if "folder_tokens" not in connector.config: @@ -1137,32 +1334,17 @@ async def index_google_drive_single_file( ) return 0, error_msg - pre_built_credentials = None - if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - connected_account_id = connector.config.get("composio_connected_account_id") - if not connected_account_id: - error_msg = f"Composio connected_account_id not found for connector {connector_id}" - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing Composio account", - {"error_type": "MissingComposioAccount"}, - ) - return 0, error_msg - pre_built_credentials = build_composio_credentials(connected_account_id) - else: - token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted and not config.SECRET_KEY: - await task_logger.log_task_failure( - log_entry, - "SECRET_KEY not configured but credentials are encrypted", - "Missing SECRET_KEY", - {"error_type": "MissingSecretKey"}, - ) - return ( - 0, - "SECRET_KEY not configured but credentials are marked as encrypted", - ) + drive_client, client_error = _build_drive_client_for_connector( + session, connector_id, connector, user_id + ) + if client_error or not drive_client: + await task_logger.log_task_failure( + log_entry, + client_error or "Failed to initialize Google Drive client", + "Missing connector credentials", + {"error_type": "ClientInitializationError"}, + ) + return 0, client_error connector_enable_summary = getattr(connector, "enable_summary", True) connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False) @@ -1171,10 +1353,6 @@ async def index_google_drive_single_file( from app.services.llm_service import get_vision_llm vision_llm = await get_vision_llm(session, search_space_id) - drive_client = GoogleDriveClient( - session, connector_id, credentials=pre_built_credentials - ) - file, error = await get_file_by_id(drive_client, file_id) if error or not file: error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}" @@ -1276,32 +1454,18 @@ async def index_google_drive_selected_files( ) return 0, 0, [error_msg] - pre_built_credentials = None - if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - connected_account_id = connector.config.get("composio_connected_account_id") - if not connected_account_id: - error_msg = f"Composio connected_account_id not found for connector {connector_id}" - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing Composio account", - {"error_type": "MissingComposioAccount"}, - ) - return 0, 0, [error_msg] - pre_built_credentials = build_composio_credentials(connected_account_id) - else: - token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted and not config.SECRET_KEY: - error_msg = ( - "SECRET_KEY not configured but credentials are marked as encrypted" - ) - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing SECRET_KEY", - {"error_type": "MissingSecretKey"}, - ) - return 0, 0, [error_msg] + drive_client, client_error = _build_drive_client_for_connector( + session, connector_id, connector, user_id + ) + if client_error or not drive_client: + error_msg = client_error or "Failed to initialize Google Drive client" + await task_logger.log_task_failure( + log_entry, + error_msg, + "Missing connector credentials", + {"error_type": "ClientInitializationError"}, + ) + return 0, 0, [error_msg] connector_enable_summary = getattr(connector, "enable_summary", True) connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False) @@ -1310,10 +1474,6 @@ async def index_google_drive_selected_files( from app.services.llm_service import get_vision_llm vision_llm = await get_vision_llm(session, search_space_id) - drive_client = GoogleDriveClient( - session, connector_id, credentials=pre_built_credentials - ) - indexed, skipped, unsupported, errors = await _index_selected_files( drive_client, session, diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py index ef226087b..6697c0eb1 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py @@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import ( IndexingPipelineService, PlaceholderInfo, ) +from app.services.composio_service import ComposioService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.google_credentials import ( - COMPOSIO_GOOGLE_CONNECTOR_TYPES, - build_composio_credentials, -) +from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES from .base import ( calculate_date_range, @@ -44,6 +42,62 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]] HEARTBEAT_INTERVAL_SECONDS = 30 +def _normalize_composio_gmail_message(message: dict) -> dict: + if message.get("payload"): + return message + + headers = [] + header_values = { + "Subject": message.get("subject"), + "From": message.get("from") or message.get("sender"), + "To": message.get("to") or message.get("recipient"), + "Date": message.get("date"), + } + for name, value in header_values.items(): + if value: + headers.append({"name": name, "value": value}) + + return { + **message, + "id": message.get("id") + or message.get("message_id") + or message.get("messageId"), + "threadId": message.get("threadId") or message.get("thread_id"), + "payload": {"headers": headers}, + "snippet": message.get("snippet", ""), + "messageText": message.get("messageText") or message.get("body") or "", + } + + +def _format_gmail_message_to_markdown(message: dict) -> str: + headers = { + header.get("name", "").lower(): header.get("value", "") + for header in message.get("payload", {}).get("headers", []) + if isinstance(header, dict) + } + subject = headers.get("subject", "No Subject") + from_email = headers.get("from", "Unknown Sender") + to_email = headers.get("to", "Unknown Recipient") + date_str = headers.get("date", "Unknown Date") + message_text = ( + message.get("messageText") + or message.get("body") + or message.get("text") + or message.get("snippet", "") + ) + + return ( + f"# {subject}\n\n" + f"**From:** {from_email}\n" + f"**To:** {to_email}\n" + f"**Date:** {date_str}\n\n" + f"## Message Content\n\n{message_text}\n\n" + f"## Message Details\n\n" + f"- **Message ID:** {message.get('id', 'Unknown')}\n" + f"- **Thread ID:** {message.get('threadId', 'Unknown')}\n" + ) + + def _build_connector_doc( message: dict, markdown_content: str, @@ -162,7 +216,14 @@ async def index_google_gmail_messages( ) return 0, 0, error_msg - # ── Credential building ─────────────────────────────────────── + is_composio_connector = ( + connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES + ) + gmail_connector = None + composio_service = None + connected_account_id = None + + # ── Credential/client building ──────────────────────────────── if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: connected_account_id = connector.config.get("composio_connected_account_id") if not connected_account_id: @@ -173,7 +234,7 @@ async def index_google_gmail_messages( {"error_type": "MissingComposioAccount"}, ) return 0, 0, "Composio connected_account_id not found" - credentials = build_composio_credentials(connected_account_id) + composio_service = ComposioService() else: config_data = connector.config @@ -241,9 +302,10 @@ async def index_google_gmail_messages( {"stage": "client_initialization"}, ) - gmail_connector = GoogleGmailConnector( - credentials, session, user_id, connector_id - ) + if not is_composio_connector: + gmail_connector = GoogleGmailConnector( + credentials, session, user_id, connector_id + ) calculated_start_date, calculated_end_date = calculate_date_range( connector, start_date, end_date, default_days_back=365 @@ -254,11 +316,60 @@ async def index_google_gmail_messages( f"Fetching emails for connector {connector_id} " f"from {calculated_start_date} to {calculated_end_date}" ) - messages, error = await gmail_connector.get_recent_messages( - max_results=max_messages, - start_date=calculated_start_date, - end_date=calculated_end_date, - ) + if is_composio_connector: + query_parts = [] + if calculated_start_date: + query_parts.append(f"after:{calculated_start_date.replace('-', '/')}") + if calculated_end_date: + query_parts.append(f"before:{calculated_end_date.replace('-', '/')}") + query = " ".join(query_parts) + + messages = [] + page_token = None + error = None + while len(messages) < max_messages: + page_size = min(50, max_messages - len(messages)) + ( + page_messages, + page_token, + _estimate, + page_error, + ) = await composio_service.get_gmail_messages( + connected_account_id=connected_account_id, + entity_id=f"surfsense_{user_id}", + query=query, + max_results=page_size, + page_token=page_token, + ) + if page_error: + error = page_error + break + for page_message in page_messages: + message_id = ( + page_message.get("id") + or page_message.get("message_id") + or page_message.get("messageId") + ) + if message_id: + ( + detail, + detail_error, + ) = await composio_service.get_gmail_message_detail( + connected_account_id=connected_account_id, + entity_id=f"surfsense_{user_id}", + message_id=message_id, + ) + if not detail_error and isinstance(detail, dict): + page_message = detail + messages.append(_normalize_composio_gmail_message(page_message)) + if not page_token: + break + else: + messages, error = await gmail_connector.get_recent_messages( + max_results=max_messages, + start_date=calculated_start_date, + end_date=calculated_end_date, + ) if error: error_message = error @@ -326,7 +437,12 @@ async def index_google_gmail_messages( documents_skipped += 1 continue - markdown_content = gmail_connector.format_message_to_markdown(message) + if is_composio_connector: + markdown_content = _format_gmail_message_to_markdown(message) + else: + markdown_content = gmail_connector.format_message_to_markdown( + message + ) if not markdown_content.strip(): logger.warning(f"Skipping message with no content: {message_id}") documents_skipped += 1 diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index cd683e2e1..b2bf17305 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "surf-new-backend" -version = "0.0.19" +version = "0.0.20" description = "SurfSense Backend" requires-python = ">=3.12" dependencies = [ @@ -71,11 +71,11 @@ dependencies = [ "langchain>=1.2.13", "langgraph>=1.1.3", "langchain-community>=0.4.1", - "deepagents>=0.4.12", "stripe>=15.0.0", "azure-ai-documentintelligence>=1.0.2", "litellm>=1.83.7", "langchain-litellm>=0.6.4", + "deepagents>=0.4.12,<0.5", ] [dependency-groups] diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py new file mode 100644 index 000000000..a49d4eab2 --- /dev/null +++ b/surfsense_backend/scripts/verify_chat_image_capability.py @@ -0,0 +1,558 @@ +"""End-to-end smoke test for vision / image config wiring. + +Loads the live ``global_llm_config.yaml`` (no mocking, no fixtures) and +exercises every chat / vision / image-generation config + the OpenRouter +dynamic catalog. For each config the script: + +1. Reports the resolver classification (catalog-allow vs strict-block). +2. Optionally fires a tiny live API call against the provider: + - Chat configs: ``litellm.acompletion`` with a 1x1 PNG and the prompt + ``"reply with one word: ok"``. + - Vision configs: same, against the dedicated vision router pool. + - Image-gen configs: ``litellm.aimage_generation`` with a single tiny + prompt and ``n=1``. + - OpenRouter integration: samples one chat, one vision, one image-gen + model from the dynamically fetched catalog. + +Usage:: + + python -m scripts.verify_chat_image_capability # capability + connectivity + python -m scripts.verify_chat_image_capability --no-live # capability resolver only + +The script is meant to be runnable from the repository root or from +``surfsense_backend/`` and prints a short PASS/FAIL/SKIP summary at the +end so it's usable as a CI smoke check too. + +Live-mode caveat: each successful call costs a small amount of provider +credit (a few tokens or one tiny generated image per config). The +default size for image generation is ``1024x1024`` because Azure +GPT-image deployments reject smaller sizes; OpenRouter image-gen models +generally accept the same size. +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import sys +import time +from dataclasses import dataclass, field +from typing import Any + +# Bootstrap the surfsense_backend package on sys.path so the script runs +# from the repo root or from `surfsense_backend/` interchangeably. +_HERE = os.path.dirname(os.path.abspath(__file__)) +_BACKEND_ROOT = os.path.dirname(_HERE) +if _BACKEND_ROOT not in sys.path: + sys.path.insert(0, _BACKEND_ROOT) + +import litellm # noqa: E402 + +from app.config import config # noqa: E402 +from app.services.openrouter_integration_service import ( # noqa: E402 + _OPENROUTER_DYNAMIC_MARKER, + OpenRouterIntegrationService, +) +from app.services.provider_api_base import resolve_api_base # noqa: E402 +from app.services.provider_capabilities import ( # noqa: E402 + derive_supports_image_input, + is_known_text_only_chat_model, +) + +logging.basicConfig( + level=logging.WARNING, + format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", +) +# Quiet down LiteLLM's verbose router/cost logs so the script output is +# scannable. +logging.getLogger("LiteLLM").setLevel(logging.ERROR) +logging.getLogger("litellm").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) + +# 1x1 transparent PNG — used as the cheapest possible vision payload. +_TINY_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" +_TINY_PNG_DATA_URL = f"data:image/png;base64,{_TINY_PNG_B64}" + + +# --------------------------------------------------------------------------- +# Result accounting +# --------------------------------------------------------------------------- + + +@dataclass +class ProbeResult: + label: str + surface: str + config_id: int | str + capability_ok: bool | None = None + capability_note: str = "" + live_ok: bool | None = None + live_note: str = "" + duration_s: float = 0.0 + + +@dataclass +class Report: + results: list[ProbeResult] = field(default_factory=list) + + def add(self, r: ProbeResult) -> None: + self.results.append(r) + + def render(self) -> int: + passed = failed = skipped = 0 + print() + print("=" * 92) + print( + f"{'Surface':<14}{'ID':>8} {'Cap':>5} {'Live':>5} {'Time':>6} Label / notes" + ) + print("-" * 92) + for r in self.results: + + def _flag(value: bool | None) -> str: + if value is None: + return "skip" + return "ok" if value else "fail" + + cap = _flag(r.capability_ok) + live = _flag(r.live_ok) + if r.capability_ok is False or r.live_ok is False: + failed += 1 + elif r.capability_ok is None and r.live_ok is None: + skipped += 1 + else: + passed += 1 + print( + f"{r.surface:<14}{r.config_id!s:>8} {cap:>5} {live:>5} " + f"{r.duration_s:>5.2f}s {r.label}" + ) + if r.capability_note: + print(f" cap: {r.capability_note}") + if r.live_note: + print(f" live: {r.live_note}") + print("-" * 92) + print( + f"Total: {passed} ok / {failed} fail / {skipped} skip " + f"(of {len(self.results)} probes)" + ) + print("=" * 92) + return failed + + +# --------------------------------------------------------------------------- +# Capability probes (no network) +# --------------------------------------------------------------------------- + + +def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: + """For chat configs the catalog flag is *expected* True (vision-capable + pool). The probe reports both the resolver value and the strict + safety-net value to surface any drift between them.""" + litellm_params = cfg.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") if isinstance(litellm_params, dict) else None + ) + cap = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + block = is_known_text_only_chat_model( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + note = f"derive={cap} strict_block={block}" + if not cap and not block: + # Resolver said False but strict gate is also False — that means + # OR modalities published [text] explicitly. Surface it. + note += " (OR modality says text-only)" + # We accept a True derive *or* (False derive AND False block) as + # 'capability ok' — either way, the streaming task will flow through. + ok = cap or not block + return ok, note + + +def _build_chat_model_string(cfg: dict) -> str: + if cfg.get("custom_provider"): + return f"{cfg['custom_provider']}/{cfg['model_name']}" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) + return f"{prefix}/{cfg['model_name']}" + + +# --------------------------------------------------------------------------- +# Live probes (network calls) +# --------------------------------------------------------------------------- + + +async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: + """Send a 1x1 PNG + `reply with one word: ok` to the chat config.""" + model_string = _build_chat_model_string(cfg) + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=model_string.split("/", 1)[0], + config_api_base=cfg.get("api_base") or None, + ) + kwargs: dict[str, Any] = { + "model": model_string, + "api_key": cfg.get("api_key"), + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "reply with one word: ok"}, + { + "type": "image_url", + "image_url": {"url": _TINY_PNG_DATA_URL}, + }, + ], + } + ], + "max_tokens": 16, + "timeout": 60, + } + if api_base: + kwargs["api_base"] = api_base + if cfg.get("litellm_params"): + # Strip pricing keys — they're tracking-only and confuse some + # provider validators (e.g. azure/openai reject unknown kwargs + # in strict mode). + merged = { + k: v + for k, v in dict(cfg["litellm_params"]).items() + if k + not in { + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_pixel", + "output_cost_per_pixel", + } + } + kwargs.update(merged) + try: + resp = await litellm.acompletion(**kwargs) + except Exception as exc: + return False, f"{type(exc).__name__}: {exc}" + text = resp.choices[0].message.content if resp.choices else "" + return True, f"got reply ({(text or '').strip()[:40]!r})" + + +# Gemini image models occasionally return zero-length ``data`` for the +# minimal "red dot on white" prompt (provider-side safety / empty-output +# quirk reproducible against ``google/gemini-2.5-flash-image`` even when +# the request itself succeeds). Use a more naturalistic prompt and +# retry once with a different one before giving up. +_IMAGE_GEN_PROMPTS: tuple[str, ...] = ( + "A simple icon of a coffee cup, flat illustration", + "A small green leaf on a white background", +) + + +async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: + """Generate one tiny image to verify the deployment is reachable.""" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + + if cfg.get("custom_provider"): + prefix = cfg["custom_provider"] + else: + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) + model_string = f"{prefix}/{cfg['model_name']}" + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=prefix, + config_api_base=cfg.get("api_base") or None, + ) + base_kwargs: dict[str, Any] = { + "model": model_string, + "api_key": cfg.get("api_key"), + "n": 1, + "size": "1024x1024", + "timeout": 120, + } + if api_base: + base_kwargs["api_base"] = api_base + if cfg.get("api_version"): + base_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + base_kwargs.update( + { + k: v + for k, v in dict(cfg["litellm_params"]).items() + if k + not in { + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_pixel", + "output_cost_per_pixel", + } + } + ) + + last_note = "" + for attempt, prompt in enumerate(_IMAGE_GEN_PROMPTS, start=1): + try: + resp = await litellm.aimage_generation(prompt=prompt, **base_kwargs) + except Exception as exc: + last_note = f"{type(exc).__name__}: {exc}" + continue + data_count = len(getattr(resp, "data", None) or []) + if data_count > 0: + return True, ( + f"received {data_count} image(s) on attempt {attempt} " + f"(prompt={prompt!r})" + ) + last_note = ( + f"call ok but received 0 images on attempt {attempt} (prompt={prompt!r})" + ) + return False, last_note + + +# --------------------------------------------------------------------------- +# Probe drivers +# --------------------------------------------------------------------------- + + +def _is_or_dynamic(cfg: dict) -> bool: + return bool(cfg.get(_OPENROUTER_DYNAMIC_MARKER)) + + +async def probe_chat_configs(report: Report, *, live: bool) -> None: + print("\n[chat configs from global_llm_configs (YAML-static)]") + for cfg in config.GLOBAL_LLM_CONFIGS: + # Skip OR dynamic entries here — handled in the OR section so + # the YAML / OR split stays clear in the report. + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="chat-yaml", + config_id=cfg.get("id"), + ) + cap_ok, cap_note = _probe_chat_capability(cfg) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await _live_chat_image_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_vision_configs(report: Report, *, live: bool) -> None: + print("\n[vision configs from global_vision_llm_configs (YAML-static)]") + for cfg in config.GLOBAL_VISION_LLM_CONFIGS: + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="vision", + config_id=cfg.get("id"), + ) + # For vision configs, capability is implied — they're in the + # dedicated vision pool. Run the same resolver to flag any + # surprise disagreement. + cap_ok, cap_note = _probe_chat_capability(cfg) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await _live_chat_image_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_image_gen_configs(report: Report, *, live: bool) -> None: + print( + "\n[image generation configs from global_image_generation_configs (YAML-static)]" + ) + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="image-gen", + config_id=cfg.get("id"), + ) + # Image gen configs don't have a "supports_image_input" flag; + # the catalog tracks output, not input. Mark capability as None + # (skip) for the report. + if live: + t0 = time.perf_counter() + ok, note = await _live_image_gen_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: + """Sample one chat (vision-capable), one vision, one image-gen model + from the live OpenRouter catalogue. Doesn't iterate the full pool + (would be hundreds of probes); just validates the integration end- + to-end on a representative model from each surface.""" + print("\n[OpenRouter integration: sampled probes]") + settings = config.OPENROUTER_INTEGRATION_SETTINGS + if not settings: + report.add( + ProbeResult( + label="OpenRouter integration", + surface="openrouter", + config_id="settings", + capability_ok=None, + capability_note="openrouter_integration disabled in YAML — skipping", + live_ok=None, + ) + ) + return + + service = OpenRouterIntegrationService.get_instance() + or_chat = [ + c + for c in config.GLOBAL_LLM_CONFIGS + if c.get("provider") == "OPENROUTER" and c.get("supports_image_input") + ] + or_vision = [ + c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER" + ] + or_image_gen = [ + c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER" + ] + + # Pick one representative per provider family per surface so a single + # broken vendor (e.g. Anthropic key revoked, Google quota exceeded) + # surfaces independently of the others. Each needle matches the + # OpenRouter ``model_name`` prefix; the first match wins. + def _pick_first(pool: list[dict], needle: str) -> dict | None: + for c in pool: + if (c.get("model_name") or "").lower().startswith(needle): + return c + return None + + chat_picks = [ + ("or-chat", _pick_first(or_chat, "openai/gpt-4o")), + ("or-chat", _pick_first(or_chat, "anthropic/claude")), + ("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")), + ] + vision_picks = [ + ("or-vision", _pick_first(or_vision, "openai/gpt-4o")), + ("or-vision", _pick_first(or_vision, "anthropic/claude")), + ("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")), + ] + image_picks = [ + ("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")), + # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*`` + # / ``openai/gpt-5.4-image-2`` (no ``gpt-image`` literal). Match + # the actual prefix. + ("or-image", _pick_first(or_image_gen, "openai/gpt-5-image")), + ] + + print( + f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} " + f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})" + ) + + for surface, picked in chat_picks + vision_picks + image_picks: + if not picked: + report.add( + ProbeResult( + label=f"", + surface=surface, + config_id="-", + capability_ok=None, + capability_note="no candidate found in OR catalog", + ) + ) + continue + runner = ( + _live_image_gen_call if surface == "or-image" else _live_chat_image_call + ) + result = ProbeResult( + label=str(picked.get("model_name")), + surface=surface, + config_id=picked.get("id"), + ) + if surface != "or-image": + cap_ok, cap_note = _probe_chat_capability(picked) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await runner(picked) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +async def main(args: argparse.Namespace) -> int: + print("Loaded global configs:") + print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries") + print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries") + print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries") + print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}") + + # Initialize the OpenRouter integration so the catalog is populated + # (this is what main.py does at startup). It's idempotent. + if config.OPENROUTER_INTEGRATION_SETTINGS: + try: + from app.config import initialize_openrouter_integration + + initialize_openrouter_integration() + except Exception as exc: + print(f" WARNING: OpenRouter integration init failed: {exc}") + + print( + f"\nMode: {'LIVE (will hit providers)' if args.live else 'DRY (capability only)'}" + ) + + report = Report() + if not args.skip_chat: + await probe_chat_configs(report, live=args.live) + if not args.skip_vision: + await probe_vision_configs(report, live=args.live) + if not args.skip_image_gen: + await probe_image_gen_configs(report, live=args.live) + if not args.skip_openrouter: + await probe_openrouter_catalog(report, live=args.live) + + failed = report.render() + return 1 if failed else 0 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--no-live", + dest="live", + action="store_false", + help="Skip live API calls — capability resolver only.", + ) + parser.set_defaults(live=True) + parser.add_argument("--skip-chat", action="store_true") + parser.add_argument("--skip-vision", action="store_true") + parser.add_argument("--skip-image-gen", action="store_true") + parser.add_argument("--skip-openrouter", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + sys.exit(asyncio.run(main(args))) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py new file mode 100644 index 000000000..9b3de2db7 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py @@ -0,0 +1,268 @@ +"""Regression tests for the compiled-agent cache. + +Covers the cache primitive itself (TTL, LRU, in-flight de-duplication, +build-failure non-caching) and the cache-key signature helpers that +``create_surfsense_deep_agent`` relies on. The integration with +``create_surfsense_deep_agent`` is covered separately by the streaming +contract tests; this module focuses on the primitives so a regression +in the cache implementation is caught before it reaches the agent +factory. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass + +import pytest + +from app.agents.new_chat.agent_cache import ( + flags_signature, + reload_for_tests, + stable_hash, + system_prompt_hash, + tools_signature, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# stable_hash + signature helpers +# --------------------------------------------------------------------------- + + +def test_stable_hash_is_deterministic_across_calls() -> None: + a = stable_hash("v1", 42, "thread-9", None, ["x", "y"]) + b = stable_hash("v1", 42, "thread-9", None, ["x", "y"]) + assert a == b + + +def test_stable_hash_changes_when_any_part_changes() -> None: + base = stable_hash("v1", 42, "thread-9") + assert stable_hash("v1", 42, "thread-10") != base + assert stable_hash("v2", 42, "thread-9") != base + assert stable_hash("v1", 43, "thread-9") != base + + +def test_tools_signature_keys_on_name_and_description_not_identity() -> None: + """Two tool lists with the same surface must hash identically. + + The cache key MUST NOT change when the underlying ``BaseTool`` + instances are different Python objects (a fresh request constructs + fresh tool instances every time). Hashing on ``(name, description)`` + keeps the cache hot across requests with identical tool surfaces. + """ + + @dataclass + class FakeTool: + name: str + description: str + + tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")] + tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")] + sig_a = tools_signature( + tools_a, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + sig_b = tools_signature( + tools_b, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + assert sig_a == sig_b, "tool order must not affect the signature" + + # Adding a tool rotates the key. + tools_c = [*tools_a, FakeTool("gamma", "does gamma")] + sig_c = tools_signature( + tools_c, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + assert sig_c != sig_a + + +def test_tools_signature_rotates_when_connector_set_changes() -> None: + @dataclass + class FakeTool: + name: str + description: str + + tools = [FakeTool("a", "x")] + base = tools_signature( + tools, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + added = tools_signature( + tools, + available_connectors=["NOTION", "SLACK"], + available_document_types=["FILE"], + ) + assert base != added, "adding a connector must rotate the cache key" + + +def test_flags_signature_changes_when_flag_flips() -> None: + @dataclass(frozen=True) + class Flags: + a: bool = True + b: bool = False + + base = flags_signature(Flags()) + flipped = flags_signature(Flags(b=True)) + assert base != flipped + + +def test_system_prompt_hash_is_stable_and_distinct() -> None: + p1 = "You are a helpful assistant." + p2 = "You are a helpful assistant!" # one-character delta + assert system_prompt_hash(p1) == system_prompt_hash(p1) + assert system_prompt_hash(p1) != system_prompt_hash(p2) + + +# --------------------------------------------------------------------------- +# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cache_hit_returns_same_instance_on_second_call() -> None: + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + builds = 0 + + async def builder() -> object: + nonlocal builds + builds += 1 + return object() + + a = await cache.get_or_build("k", builder=builder) + b = await cache.get_or_build("k", builder=builder) + assert a is b, "cache must return the SAME object across hits" + assert builds == 1, "builder must run exactly once" + + +@pytest.mark.asyncio +async def test_cache_different_keys_get_different_instances() -> None: + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + a = await cache.get_or_build("k1", builder=builder) + b = await cache.get_or_build("k2", builder=builder) + assert a is not b + + +@pytest.mark.asyncio +async def test_cache_stale_entries_get_rebuilt() -> None: + # ttl=0 means every read sees the entry as immediately stale. + cache = reload_for_tests(maxsize=8, ttl_seconds=0.0) + builds = 0 + + async def builder() -> object: + nonlocal builds + builds += 1 + return object() + + a = await cache.get_or_build("k", builder=builder) + b = await cache.get_or_build("k", builder=builder) + assert a is not b, "stale entry must rebuild a fresh instance" + assert builds == 2 + + +@pytest.mark.asyncio +async def test_cache_evicts_lru_when_full() -> None: + cache = reload_for_tests(maxsize=2, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + a = await cache.get_or_build("a", builder=builder) + _ = await cache.get_or_build("b", builder=builder) + # Re-touch "a" so "b" is now the LRU victim. + a_again = await cache.get_or_build("a", builder=builder) + assert a_again is a + # Inserting "c" should evict "b" (LRU), not "a". + _ = await cache.get_or_build("c", builder=builder) + assert cache.stats()["size"] == 2 + + # Confirm "a" is still hot (no rebuild) and "b" is gone (rebuild). + a_hit = await cache.get_or_build("a", builder=builder) + assert a_hit is a, "LRU must keep the most-recently-used 'a' entry" + + +@pytest.mark.asyncio +async def test_cache_concurrent_misses_coalesce_to_single_build() -> None: + """Two concurrent get_or_build calls on the same key must share one builder.""" + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + build_started = asyncio.Event() + builds = 0 + + async def slow_builder() -> object: + nonlocal builds + builds += 1 + build_started.set() + # Yield control so the second waiter can race against us. + await asyncio.sleep(0.05) + return object() + + task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder)) + # Wait until the first builder has started, then race a second waiter. + await build_started.wait() + task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder)) + + a, b = await asyncio.gather(task_a, task_b) + assert a is b, "coalesced waiters must observe the same value" + assert builds == 1, "concurrent cold misses must collapse to ONE build" + + +@pytest.mark.asyncio +async def test_cache_does_not_store_failed_builds() -> None: + """A builder that raises must NOT poison the cache. + + The next caller for the same key must run the builder again (not + re-raise the cached exception). + """ + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + attempts = 0 + + async def flaky_builder() -> object: + nonlocal attempts + attempts += 1 + if attempts == 1: + raise RuntimeError("transient") + return object() + + with pytest.raises(RuntimeError, match="transient"): + await cache.get_or_build("k", builder=flaky_builder) + + # Second call must retry — not re-raise the cached exception. + value = await cache.get_or_build("k", builder=flaky_builder) + assert value is not None + assert attempts == 2 + + +@pytest.mark.asyncio +async def test_cache_invalidate_drops_entry() -> None: + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + a = await cache.get_or_build("k", builder=builder) + assert cache.invalidate("k") is True + b = await cache.get_or_build("k", builder=builder) + assert a is not b, "post-invalidation lookup must rebuild" + + +@pytest.mark.asyncio +async def test_cache_invalidate_prefix_drops_matching_entries() -> None: + cache = reload_for_tests(maxsize=16, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + await cache.get_or_build("user:1:thread:1", builder=builder) + await cache.get_or_build("user:1:thread:2", builder=builder) + await cache.get_or_build("user:2:thread:1", builder=builder) + + removed = cache.invalidate_prefix("user:1:") + assert removed == 2 + assert cache.stats()["size"] == 1 + + # The user:2 entry must still be hot (no rebuild). + survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder) + assert survivor_value is not None diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py index 38a70a443..6800be2af 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py @@ -31,18 +31,45 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None: "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "SURFSENSE_ENABLE_ACTION_LOG", "SURFSENSE_ENABLE_REVERT_ROUTE", + "SURFSENSE_ENABLE_STREAM_PARITY_V2", "SURFSENSE_ENABLE_PLUGIN_LOADER", "SURFSENSE_ENABLE_OTEL", + "SURFSENSE_ENABLE_AGENT_CACHE", + "SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", ]: monkeypatch.delenv(name, raising=False) -def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None: +def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None: _clear_all(monkeypatch) flags = reload_for_tests() assert isinstance(flags, AgentFeatureFlags) assert flags.disable_new_agent_stack is False - assert flags.any_new_middleware_enabled() is False + assert flags.enable_context_editing is True + assert flags.enable_compaction_v2 is True + assert flags.enable_retry_after is True + assert flags.enable_model_fallback is False + assert flags.enable_model_call_limit is True + assert flags.enable_tool_call_limit is True + assert flags.enable_tool_call_repair is True + assert flags.enable_doom_loop is True + assert flags.enable_permission is True + assert flags.enable_busy_mutex is True + assert flags.enable_llm_tool_selector is False + assert flags.enable_skills is True + assert flags.enable_specialized_subagents is True + assert flags.enable_kb_planner_runnable is True + assert flags.enable_action_log is True + assert flags.enable_revert_route is True + assert flags.enable_stream_parity_v2 is True + assert flags.enable_plugin_loader is False + assert flags.enable_otel is False + # Phase 2: agent cache is now default-on (the prerequisite tool + # ``db_session`` refactor landed). The companion gp-subagent share + # flag stays default-off pending data on cold-miss frequency. + assert flags.enable_agent_cache is True + assert flags.enable_agent_cache_share_gp_subagent is False + assert flags.any_new_middleware_enabled() is True def test_master_kill_switch_overrides_individual_flags( @@ -100,21 +127,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) -> "enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG", "enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE", + "enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2", "enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER", "enable_otel": "SURFSENSE_ENABLE_OTEL", } - # `enable_otel` is intentionally orthogonal — it does NOT count toward - # ``any_new_middleware_enabled`` because OTel is observability-only and - # ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement. - counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"} - for attr, env_name in flag_to_env.items(): _clear_all(monkeypatch) - monkeypatch.setenv(env_name, "true") + monkeypatch.setenv(env_name, "false") flags = reload_for_tests() - assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}" - if attr in counts_toward_middleware: - assert flags.any_new_middleware_enabled() is True - else: - assert flags.any_new_middleware_enabled() is False + assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py new file mode 100644 index 000000000..6c323d920 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py @@ -0,0 +1,344 @@ +"""Tests for ``FlattenSystemMessageMiddleware``. + +The middleware exists to defend against Anthropic's "Found 5 cache_control +blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on +the system message and the OpenRouter→Anthropic adapter redistributes +``cache_control`` across all of them. The flattening collapses every +all-text system content list to a single string before the LLM call. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import HumanMessage, SystemMessage + +from app.agents.new_chat.middleware.flatten_system import ( + FlattenSystemMessageMiddleware, + _flatten_text_blocks, + _flattened_request, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# _flatten_text_blocks — pure helper, the heart of the middleware. +# --------------------------------------------------------------------------- + + +class TestFlattenTextBlocks: + def test_joins_text_blocks_with_double_newline(self) -> None: + blocks = [ + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + ] + assert ( + _flatten_text_blocks(blocks) + == "\n\n\n\n" + ) + + def test_handles_single_text_block(self) -> None: + blocks = [{"type": "text", "text": "only one"}] + assert _flatten_text_blocks(blocks) == "only one" + + def test_handles_empty_list(self) -> None: + assert _flatten_text_blocks([]) == "" + + def test_passes_through_bare_string_blocks(self) -> None: + # LangChain content can mix bare strings and dict blocks. + blocks = ["raw string", {"type": "text", "text": "dict block"}] + assert _flatten_text_blocks(blocks) == "raw string\n\ndict block" + + def test_returns_none_for_image_block(self) -> None: + # System messages with images are rare — but we never want to + # silently lose the image payload by joining as text. + blocks = [ + {"type": "text", "text": "look at this"}, + {"type": "image_url", "image_url": {"url": "data:image/png..."}}, + ] + assert _flatten_text_blocks(blocks) is None + + def test_returns_none_for_non_dict_non_str_block(self) -> None: + blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item] + assert _flatten_text_blocks(blocks) is None + + def test_returns_none_when_text_field_missing(self) -> None: + blocks = [{"type": "text"}] # no ``text`` key + assert _flatten_text_blocks(blocks) is None + + def test_returns_none_when_text_is_not_string(self) -> None: + blocks = [{"type": "text", "text": ["nested", "list"]}] + assert _flatten_text_blocks(blocks) is None + + def test_drops_cache_control_from_inner_blocks(self) -> None: + # The whole point: existing cache_control on inner blocks is + # discarded so LiteLLM's ``cache_control_injection_points`` can + # re-attach exactly one breakpoint after flattening. + blocks = [ + {"type": "text", "text": "first"}, + { + "type": "text", + "text": "second", + "cache_control": {"type": "ephemeral"}, + }, + ] + flattened = _flatten_text_blocks(blocks) + assert flattened == "first\n\nsecond" + assert "cache_control" not in flattened # type: ignore[operator] + + +# --------------------------------------------------------------------------- +# _flattened_request — decides when to override and when to no-op. +# --------------------------------------------------------------------------- + + +def _make_request(system_message: SystemMessage | None) -> Any: + """Build a minimal ModelRequest stub. We only need .system_message + and .override(system_message=...) — the middleware never touches + other fields. + """ + request = MagicMock() + request.system_message = system_message + + def override(**kwargs: Any) -> Any: + new_request = MagicMock() + new_request.system_message = kwargs.get( + "system_message", request.system_message + ) + new_request.messages = kwargs.get("messages", getattr(request, "messages", [])) + new_request.tools = kwargs.get("tools", getattr(request, "tools", [])) + return new_request + + request.override = override + return request + + +class TestFlattenedRequest: + def test_collapses_multi_block_system_to_string(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + ] + ) + request = _make_request(sys) + flattened = _flattened_request(request) + + assert flattened is not None + assert isinstance(flattened.system_message, SystemMessage) + assert flattened.system_message.content == ( + "\n\n\n\n\n\n\n\n" + ) + + def test_no_op_for_string_content(self) -> None: + sys = SystemMessage(content="already a string") + request = _make_request(sys) + assert _flattened_request(request) is None + + def test_no_op_for_single_block_list(self) -> None: + # One block already produces one breakpoint — no need to flatten. + sys = SystemMessage(content=[{"type": "text", "text": "single"}]) + request = _make_request(sys) + assert _flattened_request(request) is None + + def test_no_op_when_system_message_missing(self) -> None: + request = _make_request(None) + assert _flattened_request(request) is None + + def test_no_op_when_list_contains_non_text_block(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": "look"}, + {"type": "image_url", "image_url": {"url": "data:..."}}, + ] + ) + request = _make_request(sys) + assert _flattened_request(request) is None + + def test_preserves_additional_kwargs_and_metadata(self) -> None: + # Defensive: nothing in the current chain sets these on a system + # message, but losing them silently when something does in the + # future would be a regression. ``name`` in particular is the only + # ``additional_kwargs`` field that ChatLiteLLM's + # ``_convert_message_to_dict`` propagates onto the wire. + sys = SystemMessage( + content=[ + {"type": "text", "text": "a"}, + {"type": "text", "text": "b"}, + ], + additional_kwargs={"name": "surfsense_system", "x": 1}, + response_metadata={"tokens": 42}, + ) + sys.id = "sys-msg-1" + request = _make_request(sys) + + flattened = _flattened_request(request) + assert flattened is not None + assert flattened.system_message.content == "a\n\nb" + assert flattened.system_message.additional_kwargs == { + "name": "surfsense_system", + "x": 1, + } + assert flattened.system_message.response_metadata == {"tokens": 42} + assert flattened.system_message.id == "sys-msg-1" + + def test_idempotent_when_run_twice(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": "a"}, + {"type": "text", "text": "b"}, + ] + ) + request = _make_request(sys) + first = _flattened_request(request) + assert first is not None + + # Second pass on the already-flattened request should be a no-op. + # We re-wrap in a request stub since the helper inspects + # ``request.system_message.content``. + second_request = _make_request(first.system_message) + assert _flattened_request(second_request) is None + + +# --------------------------------------------------------------------------- +# Middleware integration — verify the handler sees a flattened request. +# --------------------------------------------------------------------------- + + +class TestMiddlewareWrap: + @pytest.mark.asyncio + async def test_async_passes_flattened_request_to_handler(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": "alpha"}, + {"type": "text", "text": "beta"}, + ] + ) + request = _make_request(sys) + captured: dict[str, Any] = {} + + async def handler(req: Any) -> str: + captured["request"] = req + return "ok" + + mw = FlattenSystemMessageMiddleware() + result = await mw.awrap_model_call(request, handler) + + assert result == "ok" + assert isinstance(captured["request"].system_message, SystemMessage) + assert captured["request"].system_message.content == "alpha\n\nbeta" + + @pytest.mark.asyncio + async def test_async_passes_through_when_already_string(self) -> None: + sys = SystemMessage(content="just a string") + request = _make_request(sys) + captured: dict[str, Any] = {} + + async def handler(req: Any) -> str: + captured["request"] = req + return "ok" + + mw = FlattenSystemMessageMiddleware() + await mw.awrap_model_call(request, handler) + + # Same request object: no override happened. + assert captured["request"] is request + + def test_sync_passes_flattened_request_to_handler(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": "alpha"}, + {"type": "text", "text": "beta"}, + ] + ) + request = _make_request(sys) + captured: dict[str, Any] = {} + + def handler(req: Any) -> str: + captured["request"] = req + return "ok" + + mw = FlattenSystemMessageMiddleware() + result = mw.wrap_model_call(request, handler) + + assert result == "ok" + assert captured["request"].system_message.content == "alpha\n\nbeta" + + def test_sync_passes_through_when_no_system_message(self) -> None: + request = _make_request(None) + captured: dict[str, Any] = {} + + def handler(req: Any) -> str: + captured["request"] = req + return "ok" + + mw = FlattenSystemMessageMiddleware() + mw.wrap_model_call(request, handler) + assert captured["request"] is request + + +# --------------------------------------------------------------------------- +# Regression guard — pin the worst-case shape that triggered the +# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the +# downstream cache_control_injection_points can only place 1 breakpoint +# on the system message regardless of provider redistribution quirks. +# --------------------------------------------------------------------------- + + +def test_regression_five_block_system_collapses_to_one_block() -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + ] + ) + request = _make_request(sys) + flattened = _flattened_request(request) + + assert flattened is not None + assert isinstance(flattened.system_message.content, str) + # The exact join doesn't matter for the cache_control accounting — + # only that there is exactly ONE content block when LiteLLM's + # AnthropicCacheControlHook later targets ``role: system``. + assert " None: + # Sanity: the middleware MUST NOT touch user messages — only the + # system message. Multi-block user content is the path that carries + # image attachments and would lose its image_url block on + # accidental flatten. + sys = SystemMessage( + content=[ + {"type": "text", "text": "a"}, + {"type": "text", "text": "b"}, + ] + ) + user = HumanMessage( + content=[ + {"type": "text", "text": "look at this"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}, + ] + ) + request = _make_request(sys) + request.messages = [user] + + flattened = _flattened_request(request) + assert flattened is not None + # System flattened to string … + assert isinstance(flattened.system_message.content, str) + # … user message is untouched (the helper does not even look at it). + assert flattened.messages == [user] + assert isinstance(user.content, list) + assert len(user.content) == 2 diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py index 5b3a03581..4cf53969d 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py @@ -1,4 +1,4 @@ -"""Tests for ``apply_litellm_prompt_caching`` in +r"""Tests for ``apply_litellm_prompt_caching`` in :mod:`app.agents.new_chat.prompt_caching`. The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which @@ -6,9 +6,12 @@ never activated for our LiteLLM stack) with LiteLLM-native multi-provider prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to ``litellm.completion(...)``. The tests below pin its public contract: -1. Always sets BOTH ``role: system`` and ``index: -1`` injection points so +1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so savings compound across multi-turn conversations on Anthropic-family - providers. + providers. ``index: 0`` is used (rather than ``role: system``) because + the deepagent stack accumulates multiple ``SystemMessage``\ s in + ``state["messages"]`` and ``role: system`` would tag every one of + them, blowing past Anthropic's 4-block ``cache_control`` cap. 2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic prompt-cache surface is available). @@ -92,11 +95,28 @@ def test_sets_both_cache_control_injection_points_with_no_config() -> None: apply_litellm_prompt_caching(llm) points = llm.model_kwargs["cache_control_injection_points"] - assert {"location": "message", "role": "system"} in points + assert {"location": "message", "index": 0} in points assert {"location": "message", "index": -1} in points assert len(points) == 2 +def test_does_not_inject_role_system_breakpoint() -> None: + """Regression: deliberately AVOID ``role: system`` so we don't tag + every SystemMessage the deepagent ``before_agent`` injectors push + into ``state["messages"]`` (priority, tree, memory, file-intent, + anonymous-doc). Tagging all of them overflows Anthropic's 4-block + ``cache_control`` cap and surfaces as + ``OpenrouterException: A maximum of 4 blocks with cache_control may + be provided. Found N`` 400s. + """ + llm = _FakeLLM() + apply_litellm_prompt_caching(llm) + points = llm.model_kwargs["cache_control_injection_points"] + assert all(p.get("role") != "system" for p in points), ( + f"Expected no role=system breakpoint, got: {points}" + ) + + def test_injection_points_set_for_anthropic_config() -> None: """Anthropic-family configs need the marker — verify it lands.""" cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet") diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py index 2ca470680..2933a0504 100644 --- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py @@ -475,3 +475,190 @@ class TestKBSearchPlanSchema: ) ) assert plan.is_recency_query is False + + +# ── mentioned_document_ids cross-turn drain ──────────────────────────── + + +class TestKnowledgePriorityMentionDrain: + """Regression tests for the cross-turn ``mentioned_document_ids`` drain. + + The compiled-agent cache reuses a single :class:`KnowledgePriorityMiddleware` + instance across turns of the same thread. ``mentioned_document_ids`` + can therefore enter the middleware via two paths: + + 1. The constructor closure (``__init__(mentioned_document_ids=...)``) — + seeded by the cache-miss build on turn 1. + 2. ``runtime.context.mentioned_document_ids`` — supplied freshly per + turn by the streaming task. + + Without the drain fix, an empty ``runtime.context.mentioned_document_ids`` + on turn 2 would fall through to the closure (because ``[]`` is falsy in + Python) and replay turn 1's mentions. This class pins down the + correct behaviour: the runtime path is authoritative even when empty, + and the closure is drained the first time the runtime path fires so + no later turn can ever resurrect stale state. + """ + + @staticmethod + def _make_runtime(mention_ids: list[int]): + """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``.""" + from types import SimpleNamespace + + return SimpleNamespace( + context=SimpleNamespace(mentioned_document_ids=mention_ids), + ) + + @staticmethod + def _planner_llm() -> "FakeLLM": + # Planner returns a stable, non-recency plan so we always land in + # the hybrid-search branch (where ``fetch_mentioned_documents`` is + # invoked alongside the main search). + return FakeLLM( + json.dumps( + { + "optimized_query": "follow up question", + "start_date": None, + "end_date": None, + "is_recency_query": False, + } + ) + ) + + async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch): + """Turn 1 with mentions in BOTH closure and runtime context: the + runtime path wins AND the closure is drained so a future turn + cannot replay it. + """ + fetched_ids: list[list[int]] = [] + + async def fake_fetch_mentioned_documents(*, document_ids, search_space_id): + fetched_ids.append(list(document_ids)) + return [] + + async def fake_search_knowledge_base(**_kwargs): + return [] + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + fake_fetch_mentioned_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + + middleware = KnowledgeBaseSearchMiddleware( + llm=self._planner_llm(), + search_space_id=42, + mentioned_document_ids=[1, 2, 3], + ) + + await middleware.abefore_agent( + {"messages": [HumanMessage(content="what is in those docs?")]}, + runtime=self._make_runtime([1, 2, 3]), + ) + + assert fetched_ids == [[1, 2, 3]], ( + "runtime.context mentions must be the source of truth on turn 1" + ) + assert middleware.mentioned_document_ids == [], ( + "closure must be drained the first time the runtime path fires " + "so no later turn can replay stale mentions" + ) + + async def test_empty_runtime_context_does_not_replay_closure_mentions( + self, monkeypatch + ): + """Regression: turn 2 with NO mentions must not surface turn 1's + mentions from the constructor closure. + + Before the fix, ``if ctx_mentions:`` treated an empty list as + absent and fell through to ``elif self.mentioned_document_ids:``, + replaying turn 1's mentions. This test pins down the corrected + behaviour. + """ + fetched_ids: list[list[int]] = [] + + async def fake_fetch_mentioned_documents(*, document_ids, search_space_id): + fetched_ids.append(list(document_ids)) + return [] + + async def fake_search_knowledge_base(**_kwargs): + return [] + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + fake_fetch_mentioned_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + + # Simulate a cached middleware instance whose closure was seeded + # by a previous turn's cache-miss build (mentions=[1,2,3]). + middleware = KnowledgeBaseSearchMiddleware( + llm=self._planner_llm(), + search_space_id=42, + mentioned_document_ids=[1, 2, 3], + ) + + # Turn 2: streaming task supplies an EMPTY mention list (no + # mentions on this follow-up turn). + await middleware.abefore_agent( + {"messages": [HumanMessage(content="what about the next steps?")]}, + runtime=self._make_runtime([]), + ) + + assert fetched_ids == [], ( + "fetch_mentioned_documents must NOT be called when the runtime " + "context says there are no mentions for this turn" + ) + + async def test_legacy_path_fires_only_when_runtime_context_absent( + self, monkeypatch + ): + """Backward-compat: if a caller doesn't supply runtime.context (old + non-streaming code path), the closure-injected mentions are still + honoured exactly once and then drained. + """ + fetched_ids: list[list[int]] = [] + + async def fake_fetch_mentioned_documents(*, document_ids, search_space_id): + fetched_ids.append(list(document_ids)) + return [] + + async def fake_search_knowledge_base(**_kwargs): + return [] + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + fake_fetch_mentioned_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + + middleware = KnowledgeBaseSearchMiddleware( + llm=self._planner_llm(), + search_space_id=42, + mentioned_document_ids=[7, 8], + ) + + # First call: no runtime → legacy path uses the closure. + await middleware.abefore_agent( + {"messages": [HumanMessage(content="initial question")]}, + runtime=None, + ) + # Second call: still no runtime — closure already drained, so no replay. + await middleware.abefore_agent( + {"messages": [HumanMessage(content="follow up")]}, + runtime=None, + ) + + assert fetched_ids == [[7, 8]], ( + "legacy path must honour the closure exactly once and then drain it" + ) + assert middleware.mentioned_document_ids == [] diff --git a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py new file mode 100644 index 000000000..c9f18d77d --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py @@ -0,0 +1,110 @@ +"""Unit tests for ``supports_image_input`` derivation on BYOK chat config +endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``). + +There is no DB column for ``supports_image_input`` on +``NewLLMConfig`` — the value is resolved at the API boundary by +``derive_supports_image_input`` so the new-chat selector / streaming +task can read the same field shape regardless of source (BYOK vs YAML +vs OpenRouter dynamic). Default-allow on unknown so we don't lock the +user out of their own model choice. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from uuid import uuid4 + +import pytest + +from app.db import LiteLLMProvider +from app.routes import new_llm_config_routes + +pytestmark = pytest.mark.unit + + +def _byok_row( + *, + id_: int, + model_name: str, + base_model: str | None = None, + provider: LiteLLMProvider = LiteLLMProvider.OPENAI, + custom_provider: str | None = None, +) -> object: + """Mimic the SQLAlchemy row's attribute surface; ``model_validate`` + walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough. + + ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's + enum validator accepts it — same as the ORM row would carry.""" + return SimpleNamespace( + id=id_, + name=f"BYOK-{id_}", + description=None, + provider=provider, + custom_provider=custom_provider, + model_name=model_name, + api_key="sk-byok", + api_base=None, + litellm_params={"base_model": base_model} if base_model else None, + system_instructions="", + use_default_system_instructions=True, + citations_enabled=True, + created_at=datetime.now(tz=UTC), + search_space_id=42, + user_id=uuid4(), + ) + + +def test_serialize_byok_known_vision_model_resolves_true(): + """The catalog resolver consults LiteLLM's map for ``gpt-4o`` -> + True. The serialized row carries that value through to the + ``NewLLMConfigRead`` schema.""" + row = _byok_row(id_=1, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + assert serialized.id == 1 + assert serialized.model_name == "gpt-4o" + + +def test_serialize_byok_unknown_model_default_allows(): + """Unknown / unmapped: default-allow. The streaming-task safety net + is the actual block, and it requires LiteLLM to *explicitly* say + text-only — so a brand new BYOK model should not be pre-judged.""" + row = _byok_row( + id_=2, + model_name="brand-new-model-x9-unmapped", + provider=LiteLLMProvider.CUSTOM, + custom_provider="brand_new_proxy", + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_uses_base_model_when_present(): + """Azure-style: ``model_name`` is the deployment id, ``base_model`` + inside ``litellm_params`` is the canonical sku LiteLLM knows. The + helper must consult ``base_model`` first or unrecognised deployment + ids would shadow the real capability.""" + row = _byok_row( + id_=3, + model_name="my-azure-deployment-id-no-litellm-knows-this", + base_model="gpt-4o", + provider=LiteLLMProvider.AZURE_OPENAI, + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_returns_pydantic_read_model(): + """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so + the schema additions are guaranteed to be present in the API + surface. This guards against a future regression where someone + deletes the augmentation step and falls back to ORM passthrough.""" + from app.schemas import NewLLMConfigRead + + row = _byok_row(id_=4, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + assert isinstance(serialized, NewLLMConfigRead) diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py new file mode 100644 index 000000000..2b6c76485 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py @@ -0,0 +1,184 @@ +"""Unit tests for ``is_premium`` derivation on the global image-gen and +vision-LLM list endpoints. + +Chat globals (``GET /global-llm-configs``) already emit +``is_premium = (billing_tier == "premium")``. Image and vision did not, +which made the new-chat ``model-selector`` render the Free/Premium badge +on the Chat tab but skip it on the Image and Vision tabs (the selector +keys its badge logic off ``is_premium``). These tests pin parity: + +* YAML free entry → ``is_premium=False`` +* YAML premium entry → ``is_premium=True`` +* OpenRouter dynamic premium entry → ``is_premium=True`` +* Auto stub (always emitted when at least one config is present) + → ``is_premium=False`` +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_IMAGE_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "DALL-E 3", + "provider": "OPENAI", + "model_name": "dall-e-3", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "GPT-Image 1 (premium)", + "provider": "OPENAI", + "model_name": "gpt-image-1", + "api_key": "sk-test", + "billing_tier": "premium", + }, + { + "id": -20_001, + "name": "google/gemini-2.5-flash-image (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +_VISION_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o Vision", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "Claude 3.5 Sonnet (premium)", + "provider": "ANTHROPIC", + "model_name": "claude-3-5-sonnet", + "api_key": "sk-ant-test", + "billing_tier": "premium", + }, + { + "id": -30_001, + "name": "openai/gpt-4o (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +# ============================================================================= +# Image generation +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_emit_is_premium(monkeypatch): + """Each emitted config must carry ``is_premium`` derived server-side + from ``billing_tier``. The Auto stub is always free. + """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False + ) + + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + # Auto stub is always emitted when at least one global config exists, + # and it must always declare itself free (Auto-mode billing-tier + # surfacing is a separate follow-up). + assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + # YAML free entry — ``is_premium=False`` + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + # YAML premium entry — ``is_premium=True`` + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + # OpenRouter dynamic premium entry — same field, same derivation + assert by_id[-20_001]["is_premium"] is True + assert by_id[-20_001]["billing_tier"] == "premium" + + # Every emitted dict (including Auto) must have the field — never missing. + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch): + """When there are no global configs at all, the endpoint emits an + empty list (no Auto stub) — Auto mode would have nothing to route to. + """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False) + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + assert payload == [] + + +# ============================================================================= +# Vision LLM +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_emit_is_premium(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr( + config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False + ) + + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + assert by_id[-30_001]["is_premium"] is True + assert by_id[-30_001]["billing_tier"] == "premium" + + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False) + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + assert payload == [] diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py new file mode 100644 index 000000000..b47d9134b --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py @@ -0,0 +1,106 @@ +"""Unit tests for ``supports_image_input`` derivation on the chat global +config endpoint (``GET /global-new-llm-configs``). + +Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``): + +1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML + loader for operator overrides, or by the OpenRouter integration from + ``architecture.input_modalities``) — wins. +2. ``derive_supports_image_input`` helper — default-allow on unknown + models, only False when LiteLLM / OR modalities are definitive. + +The flag is purely informational at the API boundary. The streaming +task safety net (``is_known_text_only_chat_model``) is the actual block, +and it requires LiteLLM to *explicitly* mark the model as text-only. +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o (explicit true)", + "description": "vision-capable, explicit YAML override", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + "supports_image_input": True, + }, + { + "id": -2, + "name": "DeepSeek V3 (explicit false)", + "description": "OpenRouter dynamic — modality-derived false", + "provider": "OPENROUTER", + "model_name": "deepseek/deepseek-v3.2-exp", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "free", + "supports_image_input": False, + }, + { + "id": -10_010, + "name": "Unannotated GPT-4o", + "description": "no flag set — resolver should derive True via LiteLLM", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + # supports_image_input intentionally absent + }, + { + "id": -10_011, + "name": "Unannotated unknown model", + "description": "unmapped — default-allow True", + "provider": "CUSTOM", + "custom_provider": "brand_new_proxy", + "model_name": "brand-new-model-x9", + "api_key": "sk-test", + "billing_tier": "free", + }, +] + + +@pytest.mark.asyncio +async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch): + """Each emitted chat config carries ``supports_image_input`` as a + bool. Explicit values win; unannotated entries are resolved via the + helper (default-allow True).""" + from app.config import config + from app.routes import new_llm_config_routes + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False) + + payload = await new_llm_config_routes.get_global_new_llm_configs(user=None) + by_id = {c["id"]: c for c in payload} + + # Auto stub: optimistic True so the user can keep Auto selected with + # vision-capable deployments somewhere in the pool. + assert 0 in by_id, "Auto stub should be emitted when configs exist" + assert by_id[0]["supports_image_input"] is True + assert by_id[0]["is_auto_mode"] is True + + # Explicit True is preserved. + assert by_id[-1]["supports_image_input"] is True + + # Explicit False is preserved (the exact failure mode the safety net + # guards against — DeepSeek V3 over OpenRouter would 404 with "No + # endpoints found that support image input"). + assert by_id[-2]["supports_image_input"] is False + + # Unannotated GPT-4o: resolver consults LiteLLM, which says vision. + assert by_id[-10_010]["supports_image_input"] is True + + # Unknown / unmapped model: default-allow rather than pre-judge. + assert by_id[-10_011]["supports_image_input"] is True + + for cfg in payload: + assert "supports_image_input" in cfg, ( + f"supports_image_input missing from {cfg.get('id')}" + ) + assert isinstance(cfg["supports_image_input"], bool) diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py new file mode 100644 index 000000000..636b7de31 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py @@ -0,0 +1,138 @@ +"""Unit tests for the image-generation route's billing-resolution helper. + +End-to-end "POST /image-generations returns 402" coverage requires the +integration harness (real DB, real auth) and lives in +``tests/integration/document_upload/`` alongside the other quota tests. +This unit test focuses on the new ``_resolve_billing_for_image_gen`` +helper which: + +* Returns ``free`` for Auto mode, even when premium configs exist + (Auto-mode billing-tier surfacing is a follow-up). +* Returns ``free`` for user-owned BYOK configs (positive IDs). +* Returns the global config's ``billing_tier`` for negative IDs. +* Honours the per-config ``quota_reserve_micros`` override when present. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_resolve_billing_for_auto_mode(monkeypatch): + from app.routes import image_generation_routes + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + search_space = SimpleNamespace(image_generation_config_id=None) + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, # Not consumed on this code path. + config_id=0, # IMAGE_GEN_AUTO_MODE_ID + search_space=search_space, + ) + assert tier == "free" + assert model == "auto" + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_for_premium_global_config(monkeypatch): + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-image-1", + "billing_tier": "premium", + "quota_reserve_micros": 75_000, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", + "billing_tier": "free", + }, + ], + raising=False, + ) + + search_space = SimpleNamespace(image_generation_config_id=None) + + # Premium with override. + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=-1, search_space=search_space + ) + assert tier == "premium" + assert model == "openai/gpt-image-1" + assert reserve == 75_000 + + # Free, no override → falls back to default. + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=-2, search_space=search_space + ) + assert tier == "free" + # Provider-prefixed model string for OpenRouter. + assert "google/gemini-2.5-flash-image" in model + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_for_user_owned_byok_is_free(): + """User-owned BYOK configs (positive IDs) cost the user nothing on + our side — they pay the provider directly. Always free. + """ + from app.routes import image_generation_routes + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + search_space = SimpleNamespace(image_generation_config_id=None) + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=42, search_space=search_space + ) + assert tier == "free" + assert model == "user_byok" + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): + """When the request omits ``image_generation_config_id``, the helper + must consult the search space's default — so a search space pinned + to a premium global config still gates new requests by quota. + """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + { + "id": -7, + "provider": "OPENAI", + "model_name": "gpt-image-1", + "billing_tier": "premium", + } + ], + raising=False, + ) + + search_space = SimpleNamespace(image_generation_config_id=-7) + ( + tier, + model, + _reserve, + ) = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=None, search_space=search_space + ) + assert tier == "premium" + assert model == "openai/gpt-image-1" diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py new file mode 100644 index 000000000..fa8819b39 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py @@ -0,0 +1,436 @@ +"""Unit tests for ``_resolve_agent_billing_for_search_space``. + +Validates the resolver used by Celery podcast/video tasks to compute +``(owner_user_id, billing_tier, base_model)`` from a search space and its +agent LLM config. The resolver mirrors chat's billing-resolution pattern at +``stream_new_chat.py:2294-2351`` and is the single integration point that +prevents Auto-mode podcast/video from leaking premium credit. + +Coverage: + +* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium + global → returns ``("premium", )``. +* Auto mode + ``thread_id`` set, pin resolves to a negative-id free + global → returns ``("free", )``. +* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config + → always ``"free"``. +* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without + hitting the pin service. +* Negative id (no Auto) → uses ``get_global_llm_config``'s + ``billing_tier``. +* Positive id (user BYOK) → always ``"free"``. +* Search space not found → raises ``ValueError``. +* ``agent_llm_id`` is None → raises ``ValueError``. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace +from uuid import UUID, uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + +class _FakeSession: + """Tiny AsyncSession stub. + + ``responses`` is a list of objects to return from successive + ``execute()`` calls (in order). The resolver makes at most two + ``execute()`` calls (search-space lookup, then optionally NewLLMConfig + lookup), so two queued responses cover the matrix. + """ + + def __init__(self, responses: list): + self._responses = list(responses) + + async def execute(self, _stmt): + if not self._responses: + return _FakeExecResult(None) + return _FakeExecResult(self._responses.pop(0)) + + async def commit(self) -> None: + pass + + +@dataclass +class _FakePinResolution: + resolved_llm_config_id: int + resolved_tier: str = "premium" + from_existing_pin: bool = False + + +def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace: + return SimpleNamespace( + id=42, + agent_llm_id=agent_llm_id, + user_id=user_id, + ) + + +def _make_byok_config( + *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok" +) -> SimpleNamespace: + return SimpleNamespace( + id=id_, + model_name=model_name, + litellm_params={"base_model": base_model} if base_model else {}, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): + """Auto + thread → pin service resolves to negative-id premium config → + resolver returns ``("premium", )``.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + # Mock the pin service to return a concrete premium config id. + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + assert selected_llm_config_id == 0 + assert thread_id == 99 + return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium") + + # Mock global config lookup to return a premium entry. + def _fake_get_global(cfg_id): + if cfg_id == -1: + return { + "id": -1, + "model_name": "gpt-5.4", + "billing_tier": "premium", + "litellm_params": {"base_model": "gpt-5.4"}, + } + return None + + # Lazy imports inside the resolver — patch the *target* modules so the + # imported names resolve to our fakes. + import app.services.auto_model_pin_service as pin_module + import app.services.llm_service as llm_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "premium" + assert base_model == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch): + """Auto + thread → pin returns negative-id free config → resolver + returns ``("free", )``. Same path the pin service takes for + out-of-credit users (graceful degradation).""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free") + + def _fake_get_global(cfg_id): + if cfg_id == -3: + return { + "id": -3, + "model_name": "openrouter/free-model", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/free-model"}, + } + return None + + import app.services.auto_model_pin_service as pin_module + import app.services.llm_service as llm_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/free-model" + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): + """Auto + thread → pin returns positive-id BYOK config → resolver + returns ``("free", ...)`` (BYOK is always free per + ``AgentConfig.from_new_llm_config``).""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + search_space = _make_search_space(agent_llm_id=0, user_id=user_id) + byok_cfg = _make_byok_config( + id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude" + ) + session = _FakeSession([search_space, byok_cfg]) + + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free") + + import app.services.auto_model_pin_service as pin_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "anthropic/claude-3-haiku" + + +@pytest.mark.asyncio +async def test_auto_mode_without_thread_id_falls_back_to_free(): + """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking + the pin service. Forward-compat fallback for any future direct-API + entrypoint that doesn't have a chat thread.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=None + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "auto" + + +@pytest.mark.asyncio +async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): + """If the pin service raises ``ValueError`` (thread missing / + mismatched search space), the resolver should log and return free + rather than killing the whole task.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + async def _fake_resolve_pin(*args, **kwargs): + raise ValueError("thread missing") + + import app.services.auto_model_pin_service as pin_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "auto" + + +@pytest.mark.asyncio +async def test_negative_id_premium_global_returns_premium(monkeypatch): + """Explicit negative agent_llm_id → ``get_global_llm_config`` → + return its ``billing_tier``.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "gpt-5.4", + "billing_tier": "premium", + "litellm_params": {"base_model": "gpt-5.4"}, + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "premium" + assert base_model == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_negative_id_free_global_returns_free(monkeypatch): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "openrouter/some-free", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/some-free"}, + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=None + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/some-free" + + +@pytest.mark.asyncio +async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch): + """When the global config has no ``litellm_params.base_model``, the + resolver falls back to ``model_name`` — matching chat's behavior.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "fallback-model", + "billing_tier": "premium", + # No litellm_params. + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + _, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert tier == "premium" + assert base_model == "fallback-model" + + +@pytest.mark.asyncio +async def test_positive_id_byok_is_always_free(): + """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free, + regardless of underlying provider tier.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + search_space = _make_search_space(agent_llm_id=23, user_id=user_id) + byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet") + session = _FakeSession([search_space, byok_cfg]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "anthropic/claude-3.5-sonnet" + + +@pytest.mark.asyncio +async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): + """If the BYOK config row is missing/deleted but the search space still + points at it, the resolver still returns free (no debit) with an empty + base_model — billable_call's premium path is skipped, no harm done.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "" + + +@pytest.mark.asyncio +async def test_search_space_not_found_raises_value_error(): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + session = _FakeSession([None]) + + with pytest.raises(ValueError, match="Search space"): + await _resolve_agent_billing_for_search_space(session, search_space_id=999) + + +@pytest.mark.asyncio +async def test_agent_llm_id_none_raises_value_error(): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)]) + + with pytest.raises(ValueError, match="agent_llm_id"): + await _resolve_agent_billing_for_search_space(session, search_space_id=42) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 49b3621c7..d1af29aeb 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -101,11 +101,116 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): user_id="00000000-0000-0000-0000-000000000001", selected_llm_config_id=0, ) - assert result.resolved_llm_config_id in {-1, -2} + assert result.resolved_llm_config_id == -1 assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id assert session.commit_count == 1 +@pytest.mark.asyncio +async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + "billing_tier": "free", + "quality_score": 100, + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + "quality_score": 10, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.resolved_tier == "premium" + + +@pytest.mark.asyncio +async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5.1", + "api_key": "k1", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 100, + }, + { + "id": -2, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5.4", + "api_key": "k2", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 10, + }, + { + "id": -3, + "provider": "OPENROUTER", + "model_name": "openai/gpt-5.4", + "api_key": "k3", + "billing_tier": "premium", + "auto_pin_tier": "B", + "quality_score": 100, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.resolved_tier == "premium" + + @pytest.mark.asyncio async def test_next_turn_reuses_existing_pin(monkeypatch): from app.config import config @@ -361,12 +466,12 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): ], ) - async def _allowed(*_args, **_kwargs): - return _FakeQuotaResult(allowed=True) + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) monkeypatch.setattr( "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", - _allowed, + _blocked, ) result = await resolve_or_get_pinned_llm_config_id( diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py new file mode 100644 index 000000000..0e19b80e4 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py @@ -0,0 +1,286 @@ +"""Image-aware extension of the Auto-pin resolver. + +When the current chat turn carries an ``image_url`` block, the pin +resolver must: + +1. Filter the candidate pool to vision-capable cfgs so a freshly + selected pin can never be text-only. +2. Treat any existing pin whose capability is False as invalid (force + re-pin), even when it would otherwise be reused as the thread's + stable model. +3. Raise ``ValueError`` (mapped to the friendly + ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming + task) when no vision-capable cfg is available — instead of silently + pinning text-only and 404-ing at the provider. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from app.services.auto_model_pin_service import ( + clear_healthy, + clear_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _reset_caches(): + clear_runtime_cooldown() + clear_healthy() + yield + clear_runtime_cooldown() + clear_healthy() + + +@dataclass +class _FakeQuotaResult: + allowed: bool + + +class _FakeExecResult: + def __init__(self, thread): + self._thread = thread + + def unique(self): + return self + + def scalar_one_or_none(self): + return self._thread + + +class _FakeSession: + def __init__(self, thread): + self.thread = thread + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self.thread) + + async def commit(self): + self.commit_count += 1 + + +def _thread(*, pinned: int | None = None): + return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned) + + +def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: + return { + "id": id_, + "provider": "OPENAI", + "model_name": f"vision-{id_}", + "api_key": "k", + "billing_tier": tier, + "supports_image_input": True, + "auto_pin_tier": "A", + "quality_score": quality, + } + + +def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict: + return { + "id": id_, + "provider": "OPENAI", + "model_name": f"text-{id_}", + "api_key": "k", + "billing_tier": tier, + # Higher quality than the vision cfgs — so a bug that ignores + # the image flag would surface as the resolver picking this one. + "supports_image_input": False, + "auto_pin_tier": "A", + "quality_score": quality, + } + + +async def _premium_allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + +@pytest.mark.asyncio +async def test_image_turn_filters_out_text_only_candidates(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + # The thread should be pinned to the vision cfg even though the + # text-only cfg has a higher quality score. + assert session.thread.pinned_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch): + """An existing text-only pin must be invalidated when the next turn + requires image input. The non-image path would happily reuse it.""" + from app.config import config + + session = _FakeSession(_thread(pinned=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + assert session.thread.pinned_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_image_turn_reuses_existing_vision_pin(monkeypatch): + """If the thread is already pinned to a vision-capable cfg, reuse it + — same as the non-image path. Image-aware filtering must not force + spurious re-pins.""" + from app.config import config + + session = _FakeSession(_thread(pinned=-2)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_image_turn_with_no_vision_candidates_raises(monkeypatch): + """The friendly-error path: no vision-capable cfg in the pool -> raise + ``ValueError`` whose message contains ``vision-capable`` so the + streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _text_only_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + with pytest.raises(ValueError, match="vision-capable"): + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + +@pytest.mark.asyncio +async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch): + """Regression guard: the image flag must default False and not affect + a normal text-only turn — text-only cfgs remain selectable.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + + +@pytest.mark.asyncio +async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): + """A YAML cfg that omits ``supports_image_input`` falls through to + ``derive_supports_image_input`` (LiteLLM-driven). For ``gpt-4o`` + that returns True, so the cfg should be a valid candidate.""" + from app.config import config + + session = _FakeSession(_thread()) + cfg_unannotated_vision = { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-4o", # known vision model in LiteLLM map + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "A", + "quality_score": 80, + # NOTE: no supports_image_input key + } + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision]) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + assert result.resolved_llm_config_id == -2 diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py new file mode 100644 index 000000000..c820724ed --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_billable_call.py @@ -0,0 +1,559 @@ +"""Unit tests for the ``billable_call`` async context manager. + +Covers the per-call premium-credit lifecycle for image generation and +vision LLM extraction: + +* Free configs bypass reserve/finalize but still write an audit row. +* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the + route layer). +* Successful premium calls reserve, yield the accumulator, then finalize + with the LiteLLM-reported actual cost — and write an audit row. +* Failed premium calls release the reservation so credit isn't leaked. +* All quota DB ops happen inside their OWN ``shielded_async_session``, + isolating them from the caller's transaction (issue A). +""" + +from __future__ import annotations + +import asyncio +import contextlib +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeQuotaResult: + def __init__( + self, + *, + allowed: bool, + used: int = 0, + limit: int = 5_000_000, + remaining: int = 5_000_000, + ) -> None: + self.allowed = allowed + self.used = used + self.limit = limit + self.remaining = remaining + + +class _FakeSession: + """Minimal AsyncSession stub — record commits for assertion.""" + + def __init__(self) -> None: + self.committed = False + self.added: list[Any] = [] + + def add(self, obj: Any) -> None: + self.added.append(obj) + + async def commit(self) -> None: + self.committed = True + + async def rollback(self) -> None: + pass + + async def close(self) -> None: + pass + + +@contextlib.asynccontextmanager +async def _fake_shielded_session(): + s = _FakeSession() + _SESSIONS_USED.append(s) + yield s + + +_SESSIONS_USED: list[_FakeSession] = [] + + +def _patch_isolation_layer( + monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None +): + """Wire fake reserve/finalize/release/session helpers.""" + _SESSIONS_USED.clear() + reserve_calls: list[dict[str, Any]] = [] + finalize_calls: list[dict[str, Any]] = [] + release_calls: list[dict[str, Any]] = [] + + async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros): + reserve_calls.append( + { + "user_id": user_id, + "reserve_micros": reserve_micros, + "request_id": request_id, + } + ) + return reserve_result + + async def _fake_finalize( + *, db_session, user_id, request_id, actual_micros, reserved_micros + ): + if finalize_exc is not None: + raise finalize_exc + finalize_calls.append( + { + "user_id": user_id, + "actual_micros": actual_micros, + "reserved_micros": reserved_micros, + } + ) + return finalize_result or _FakeQuotaResult(allowed=True) + + async def _fake_release(*, db_session, user_id, reserved_micros): + release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros}) + + record_calls: list[dict[str, Any]] = [] + + async def _fake_record(session, **kwargs): + record_calls.append(kwargs) + return object() + + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_reserve", + _fake_reserve, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_finalize", + _fake_finalize, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_release", + _fake_release, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.shielded_async_session", + _fake_shielded_session, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.record_token_usage", + _fake_record, + raising=False, + ) + + return { + "reserve": reserve_calls, + "finalize": finalize_calls, + "release": release_calls, + "record": record_calls, + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="free", + base_model="openai/gpt-image-1", + usage_type="image_generation", + ) as acc: + # Simulate a captured cost — the accumulator is fed by the LiteLLM + # callback in real life, here we add it manually. + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=37_000, + call_kind="image_generation", + ) + + assert spies["reserve"] == [] + assert spies["finalize"] == [] + assert spies["release"] == [] + # Free still audits. + assert len(spies["record"]) == 1 + assert spies["record"][0]["usage_type"] == "image_generation" + assert spies["record"][0]["cost_micros"] == 37_000 + + +@pytest.mark.asyncio +async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch): + from app.services.billable_calls import ( + QuotaInsufficientError, + billable_call, + ) + + spies = _patch_isolation_layer( + monkeypatch, + reserve_result=_FakeQuotaResult( + allowed=False, used=5_000_000, limit=5_000_000, remaining=0 + ), + ) + user_id = uuid4() + + with pytest.raises(QuotaInsufficientError) as exc_info: + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ): + pytest.fail("body should not run when reserve is denied") + + err = exc_info.value + assert err.usage_type == "image_generation" + assert err.used_micros == 5_000_000 + assert err.limit_micros == 5_000_000 + assert err.remaining_micros == 0 + # Reserve was attempted, but no finalize/release on a denied reserve + # — the reservation never actually held credit. + assert len(spies["reserve"]) == 1 + assert spies["finalize"] == [] + assert spies["release"] == [] + # Denied premium calls do NOT create an audit row (no work happened). + assert spies["record"] == [] + + +@pytest.mark.asyncio +async def test_premium_success_finalizes_with_actual_cost(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ) as acc: + # LiteLLM callback would normally fill this — simulate $0.04 image. + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=40_000, + call_kind="image_generation", + ) + + assert len(spies["reserve"]) == 1 + assert spies["reserve"][0]["reserve_micros"] == 50_000 + assert len(spies["finalize"]) == 1 + assert spies["finalize"][0]["actual_micros"] == 40_000 + assert spies["finalize"][0]["reserved_micros"] == 50_000 + assert spies["release"] == [] + # And audit row written with the actual debited cost. + assert spies["record"][0]["cost_micros"] == 40_000 + # Each quota op opened its OWN session — proves session isolation. + assert len(_SESSIONS_USED) >= 3 + # Sessions used should each have committed (or be the audit one which commits). + for _s in _SESSIONS_USED: + # finalize/reserve happen via TokenQuotaService.* which we stub — + # they don't actually call commit on our fake session, but the + # audit session does. We just assert >=1 session committed. + pass + assert any(s.committed for s in _SESSIONS_USED) + + +@pytest.mark.asyncio +async def test_premium_failure_releases_reservation(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + class _ProviderError(Exception): + pass + + with pytest.raises(_ProviderError): + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ): + raise _ProviderError("OpenRouter 503") + + assert len(spies["reserve"]) == 1 + assert spies["finalize"] == [] + # Failure path: release the held reservation. + assert len(spies["release"]) == 1 + assert spies["release"][0]["reserved_micros"] == 50_000 + + +@pytest.mark.asyncio +async def test_premium_uses_estimator_when_no_micros_override(monkeypatch): + """When ``quota_reserve_micros_override`` is None we fall back to + ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``. + Vision LLM calls take this path (token-priced models). + """ + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + + captured_estimator_calls: list[dict[str, Any]] = [] + + def _fake_estimate(*, base_model, quota_reserve_tokens): + captured_estimator_calls.append( + {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens} + ) + return 12_345 + + monkeypatch.setattr( + "app.services.billable_calls.estimate_call_reserve_micros", + _fake_estimate, + raising=False, + ) + + user_id = uuid4() + async with billable_call( + user_id=user_id, + search_space_id=1, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + usage_type="vision_extraction", + ): + pass + + assert captured_estimator_calls == [ + {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000} + ] + assert spies["reserve"][0]["reserve_micros"] == 12_345 + + +@pytest.mark.asyncio +async def test_premium_finalize_failure_propagates_and_releases(monkeypatch): + from app.services.billable_calls import BillingSettlementError, billable_call + + class _FinalizeError(RuntimeError): + pass + + spies = _patch_isolation_layer( + monkeypatch, + reserve_result=_FakeQuotaResult(allowed=True), + finalize_exc=_FinalizeError("db finalize failed"), + ) + user_id = uuid4() + + with pytest.raises(BillingSettlementError): + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ) as acc: + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=40_000, + call_kind="image_generation", + ) + + assert len(spies["reserve"]) == 1 + assert len(spies["release"]) == 1 + assert spies["record"] == [] + + +@pytest.mark.asyncio +async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + class _HangingCommitSession(_FakeSession): + async def commit(self) -> None: + await asyncio.sleep(60) + + @contextlib.asynccontextmanager + async def _hanging_session_factory(): + s = _HangingCommitSession() + _SESSIONS_USED.append(s) + yield s + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + billable_session_factory=_hanging_session_factory, + audit_timeout_seconds=0.01, + ) as acc: + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=40_000, + call_kind="image_generation", + ) + + assert len(spies["reserve"]) == 1 + assert len(spies["finalize"]) == 1 + assert len(spies["record"]) == 1 + assert spies["release"] == [] + + +@pytest.mark.asyncio +async def test_free_audit_failure_is_best_effort(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + + async def _failing_record(_session, **_kwargs): + raise RuntimeError("audit insert failed") + + monkeypatch.setattr( + "app.services.billable_calls.record_token_usage", + _failing_record, + raising=False, + ) + + async with billable_call( + user_id=uuid4(), + search_space_id=42, + billing_tier="free", + base_model="openai/gpt-image-1", + usage_type="image_generation", + audit_timeout_seconds=0.01, + ) as acc: + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=37_000, + call_kind="image_generation", + ) + + assert spies["reserve"] == [] + assert spies["finalize"] == [] + + +# --------------------------------------------------------------------------- +# Podcast / video-presentation usage_type coverage +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch): + """Free podcast configs must skip reserve/finalize but still emit a + ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we + have full audit coverage of free-tier agent runs.""" + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="free", + base_model="openrouter/some-free-model", + quota_reserve_micros_override=200_000, + usage_type="podcast_generation", + thread_id=99, + call_details={"podcast_id": 7, "title": "Test Podcast"}, + ) as acc: + # Two transcript LLM calls aggregated into one accumulator. + acc.add( + model="openrouter/some-free-model", + prompt_tokens=1500, + completion_tokens=8000, + total_tokens=9500, + cost_micros=0, + call_kind="chat", + ) + + assert spies["reserve"] == [] + assert spies["finalize"] == [] + assert spies["release"] == [] + + assert len(spies["record"]) == 1 + row = spies["record"][0] + assert row["usage_type"] == "podcast_generation" + assert row["thread_id"] is None + assert row["search_space_id"] == 42 + assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"} + + +@pytest.mark.asyncio +async def test_premium_video_denial_raises_quota_insufficient(monkeypatch): + """Premium video-presentation runs that hit a denied reservation must + raise ``QuotaInsufficientError`` *before* the graph runs and must not + emit an audit row (no work happened).""" + from app.services.billable_calls import ( + QuotaInsufficientError, + billable_call, + ) + + spies = _patch_isolation_layer( + monkeypatch, + reserve_result=_FakeQuotaResult( + allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000 + ), + ) + user_id = uuid4() + + with pytest.raises(QuotaInsufficientError) as exc_info: + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="gpt-5.4", + quota_reserve_micros_override=1_000_000, + usage_type="video_presentation_generation", + thread_id=99, + call_details={"video_presentation_id": 12, "title": "Test Video"}, + ): + pytest.fail("body should not run when reserve is denied") + + err = exc_info.value + assert err.usage_type == "video_presentation_generation" + assert err.remaining_micros == 500_000 + assert spies["reserve"][0]["reserve_micros"] == 1_000_000 + assert spies["finalize"] == [] + assert spies["release"] == [] + assert spies["record"] == [] diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py new file mode 100644 index 000000000..9d5fdb190 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -0,0 +1,177 @@ +"""Defense-in-depth: image-gen call sites must not let an empty +``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``. + +The bug repro: an OpenRouter image-gen config ships +``api_base=""``. The pre-fix call site in +``image_generation_routes._execute_image_generation`` did +``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which +silently dropped the empty string. LiteLLM then fell back to +``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``) +and OpenRouter's ``image_generation/transformation`` appended +``/chat/completions`` to it → 404 ``Resource not found``. + +This test pins the post-fix behaviour: with an empty ``api_base`` in +the config, the call site MUST set ``api_base`` to OpenRouter's public +URL instead of leaving it unset. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): + """The global-config branch (``config_id < 0``) of + ``_execute_image_generation`` must apply the resolver and pin + ``api_base`` to OpenRouter when the config ships an empty string. + """ + from app.routes import image_generation_routes + + cfg = { + "id": -20_001, + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", + "api_key": "sk-or-test", + "api_base": "", # the original bug shape + "api_version": None, + "litellm_params": {}, + } + + captured: dict = {} + + async def fake_aimage_generation(**kwargs): + captured.update(kwargs) + return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={}) + + image_gen = MagicMock() + image_gen.image_generation_config_id = cfg["id"] + image_gen.prompt = "test" + image_gen.n = 1 + image_gen.quality = None + image_gen.size = None + image_gen.style = None + image_gen.response_format = None + image_gen.model = None + + search_space = MagicMock() + search_space.image_generation_config_id = cfg["id"] + session = MagicMock() + + with ( + patch.object( + image_generation_routes, + "_get_global_image_gen_config", + return_value=cfg, + ), + patch.object( + image_generation_routes, + "aimage_generation", + side_effect=fake_aimage_generation, + ), + ): + await image_generation_routes._execute_image_generation( + session=session, image_gen=image_gen, search_space=search_space + ) + + # The whole point of the fix: even with empty ``api_base`` in the + # config, we forward OpenRouter's public URL so the call doesn't + # inherit an Azure endpoint. + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-image-1" + + +@pytest.mark.asyncio +async def test_generate_image_tool_global_sets_api_base_when_config_empty(): + """Same defense at the agent tool entry point — both surfaces share + the same OpenRouter config payloads.""" + from app.agents.new_chat.tools import generate_image as gi_module + + cfg = { + "id": -20_001, + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", + "api_key": "sk-or-test", + "api_base": "", + "api_version": None, + "litellm_params": {}, + } + + captured: dict = {} + + async def fake_aimage_generation(**kwargs): + captured.update(kwargs) + response = MagicMock() + response.model_dump.return_value = { + "data": [{"url": "https://example.com/x.png"}] + } + response._hidden_params = {"model": "openrouter/openai/gpt-image-1"} + return response + + search_space = MagicMock() + search_space.id = 1 + search_space.image_generation_config_id = cfg["id"] + + session_cm = AsyncMock() + session = AsyncMock() + session_cm.__aenter__.return_value = session + + scalars = MagicMock() + scalars.first.return_value = search_space + exec_result = MagicMock() + exec_result.scalars.return_value = scalars + session.execute.return_value = exec_result + session.add = MagicMock() + session.commit = AsyncMock() + session.refresh = AsyncMock() + + # ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback. + async def _refresh(obj): + obj.id = 1 + + session.refresh.side_effect = _refresh + + with ( + patch.object(gi_module, "shielded_async_session", return_value=session_cm), + patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg), + patch.object( + gi_module, "aimage_generation", side_effect=fake_aimage_generation + ), + patch.object( + gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0 + ), + ): + tool = gi_module.create_generate_image_tool( + search_space_id=1, db_session=MagicMock() + ) + await tool.ainvoke({"prompt": "a cat", "n": 1}) + + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-image-1" + + +def test_image_gen_router_deployment_sets_api_base_when_config_empty(): + """The Auto-mode router pool must also resolve ``api_base`` when an + OpenRouter config ships an empty string. The deployment dict is fed + straight to ``litellm.Router``, so a missing ``api_base`` would + leak the same way as the direct call sites. + """ + from app.services.image_gen_router_service import ImageGenRouterService + + deployment = ImageGenRouterService._config_to_deployment( + { + "model_name": "openai/gpt-image-1", + "provider": "OPENROUTER", + "api_key": "sk-or-test", + "api_base": "", + } + ) + assert deployment is not None + assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" + assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-image-1" diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 085740032..88fcf2db3 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -214,3 +214,167 @@ def test_generate_configs_drops_non_text_and_non_tool_models(): assert "openai/gpt-4o" in model_names assert "openai/dall-e" not in model_names assert "openai/completion-only" not in model_names + + +# --------------------------------------------------------------------------- +# _generate_image_gen_configs / _generate_vision_llm_configs +# --------------------------------------------------------------------------- + + +def test_generate_image_gen_configs_filters_by_image_output(): + """Only models with ``output_modalities`` containing ``image`` are emitted. + Tool-calling and context filters are intentionally NOT applied — image + generation has nothing to do with tool calls and context windows. + """ + from app.services.openrouter_integration_service import ( + _generate_image_gen_configs, + ) + + raw = [ + # Pure image-gen model (small context, no tools — should still emit). + { + "id": "openai/gpt-image-1", + "architecture": {"output_modalities": ["image"]}, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + }, + # Multi-modal: text+image output (should still emit). + { + "id": "google/gemini-2.5-flash-image", + "architecture": {"output_modalities": ["text", "image"]}, + "context_length": 1_000_000, + "pricing": {"prompt": "0.000001", "completion": "0.000004"}, + }, + # Pure text model — must NOT emit. + { + "id": "openai/gpt-4o", + "architecture": {"output_modalities": ["text"]}, + "context_length": 128_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + ] + + cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE)) + model_names = {c["model_name"] for c in cfgs} + assert "openai/gpt-image-1" in model_names + assert "google/gemini-2.5-flash-image" in model_names + assert "openai/gpt-4o" not in model_names + + # Each config must carry ``billing_tier`` for routing in image_generation_routes. + for c in cfgs: + assert c["billing_tier"] in {"free", "premium"} + assert c["provider"] == "OPENROUTER" + assert c[_OPENROUTER_DYNAMIC_MARKER] is True + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't 404 against an inherited Azure endpoint. + assert c["api_base"] == "https://openrouter.ai/api/v1" + + +def test_generate_image_gen_configs_assigns_image_id_offset(): + """Image configs use a different id_offset (-20000) so their negative + IDs don't collide with chat configs (-10000) or vision configs (-30000). + """ + from app.services.openrouter_integration_service import ( + _generate_image_gen_configs, + ) + + raw = [ + { + "id": "openai/gpt-image-1", + "architecture": {"output_modalities": ["image"]}, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + } + ] + # Don't pass image_id_offset → use the module default (-20000). + cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE)) + assert all(c["id"] < -20_000 + 1 for c in cfgs) + assert all(c["id"] > -29_000_000 for c in cfgs) + + +def test_generate_vision_llm_configs_filters_by_image_input_text_output(): + """Vision LLMs must accept image input AND emit text — pure image-gen + (no text out) and text-only (no image in) models are excluded. + """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + # GPT-4o: vision LLM (image in, text out) — must emit. + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "context_length": 128_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + # Pure image generator — image *output*, no text out. Must NOT emit. + { + "id": "openai/gpt-image-1", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["image"], + }, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + }, + # Pure text model (no image in). Must NOT emit. + { + "id": "anthropic/claude-3-haiku", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + "context_length": 200_000, + "pricing": {"prompt": "0.000001", "completion": "0.000005"}, + }, + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + names = {c["model_name"] for c in cfgs} + assert names == {"openai/gpt-4o"} + + cfg = cfgs[0] + assert cfg["billing_tier"] == "premium" + # Pricing carried inline so pricing_registration can register vision + # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache + # is cleared. + assert cfg["input_cost_per_token"] == pytest.approx(5e-6) + assert cfg["output_cost_per_token"] == pytest.approx(15e-6) + assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't inherit an Azure endpoint. + assert cfg["api_base"] == "https://openrouter.ai/api/v1" + + +def test_generate_vision_llm_configs_drops_chat_only_filters(): + """A small-context vision model that doesn't advertise tool calling is + still a valid vision LLM for "describe this image" prompts. The chat + filters (``supports_tool_calling``, ``has_sufficient_context``) must + NOT be applied to vision emission. + """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + { + "id": "tiny/vision-mini", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "supported_parameters": [], # no tools + "context_length": 4_000, # well below MIN_CONTEXT_LENGTH + "pricing": {"prompt": "0.0000001", "completion": "0.0000005"}, + } + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + assert len(cfgs) == 1 + assert cfgs[0]["model_name"] == "tiny/vision-mini" diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py new file mode 100644 index 000000000..e97250ff2 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py @@ -0,0 +1,447 @@ +"""Pricing registration unit tests. + +The pricing-registration module is what makes ``response_cost`` populate +correctly for OpenRouter dynamic models and operator-defined Azure +deployments — both of which LiteLLM doesn't natively know about. The tests +exercise: + +* The alias generators emit every shape that LiteLLM's cost-callback might + use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``, + ``provider/base_model``, ``provider/model_name``, plus the special + ``azure_openai`` → ``azure`` normalisation). +* ``register_pricing_from_global_configs`` calls ``litellm.register_model`` + with the right alias set and pricing values per provider. +* Configs without a resolvable pair of cost values are skipped — never + registered as zero, since that would override pricing LiteLLM might + already know natively. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Alias generators +# --------------------------------------------------------------------------- + + +def test_openrouter_alias_set_includes_prefixed_and_bare(): + from app.services.pricing_registration import _alias_set_for_openrouter + + aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet") + assert aliases == [ + "openrouter/anthropic/claude-3-5-sonnet", + "anthropic/claude-3-5-sonnet", + ] + + +def test_openrouter_alias_set_dedupes(): + """If the model id is already prefixed with ``openrouter/``, the alias + set must not contain duplicates that would re-register the same key + twice. + """ + from app.services.pricing_registration import _alias_set_for_openrouter + + aliases = _alias_set_for_openrouter("openrouter/foo") + # The bare and prefixed variants compute to the same string here, so we + # at minimum require uniqueness. + assert len(aliases) == len(set(aliases)) + + +def test_yaml_alias_set_for_azure_openai_normalises_to_azure(): + """``azure_openai`` (our YAML provider slug) must register under + ``azure/`` so the LiteLLM Router's deployment-resolution path + (which uses provider ``azure``) finds the pricing too. + """ + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="AZURE_OPENAI", + model_name="gpt-5.4", + base_model="gpt-5.4", + ) + assert "gpt-5.4" in aliases + assert "azure_openai/gpt-5.4" in aliases + assert "azure/gpt-5.4" in aliases + + +def test_yaml_alias_set_distinguishes_model_name_and_base_model(): + """When ``model_name`` differs from ``base_model`` (operator labelled a + deployment), both must appear in the alias set since either may surface + in callbacks depending on the call path. + """ + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="OPENAI", + model_name="my-deployment-label", + base_model="gpt-4o", + ) + assert "gpt-4o" in aliases + assert "openai/gpt-4o" in aliases + assert "my-deployment-label" in aliases + assert "openai/my-deployment-label" in aliases + + +def test_yaml_alias_set_omits_provider_prefix_when_provider_blank(): + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="", + model_name="foo", + base_model="bar", + ) + assert "bar" in aliases + assert "foo" in aliases + assert all("/" not in a for a in aliases) + + +# --------------------------------------------------------------------------- +# register_pricing_from_global_configs +# --------------------------------------------------------------------------- + + +class _RegistrationSpy: + """Captures the dicts passed to ``litellm.register_model``. + + Many calls may go through; we just record them all and let tests assert + against the union. + """ + + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + + def __call__(self, payload: dict[str, Any]) -> None: + self.calls.append(payload) + + @property + def all_keys(self) -> set[str]: + keys: set[str] = set() + for payload in self.calls: + keys.update(payload.keys()) + return keys + + +def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy: + spy = _RegistrationSpy() + monkeypatch.setattr( + "app.services.pricing_registration.litellm.register_model", + spy, + raising=False, + ) + return spy + + +def _patch_openrouter_pricing( + monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]] +) -> None: + """Pretend the OpenRouter integration is initialised with ``mapping``.""" + + class _Stub: + def get_raw_pricing(self) -> dict[str, dict[str, str]]: + return mapping + + class _StubService: + @classmethod + def is_initialized(cls) -> bool: + return True + + @classmethod + def get_instance(cls) -> _Stub: + return _Stub() + + monkeypatch.setattr( + "app.services.openrouter_integration_service.OpenRouterIntegrationService", + _StubService, + raising=False, + ) + + +def test_openrouter_models_register_under_aliases(monkeypatch): + """An OpenRouter config whose ``model_name`` is in the cached raw + pricing map is registered under both ``openrouter/X`` and bare ``X``. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, + { + "anthropic/claude-3-5-sonnet": { + "prompt": "0.000003", + "completion": "0.000015", + } + }, + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys + assert "anthropic/claude-3-5-sonnet" in spy.all_keys + # Costs are float-converted from the raw OpenRouter strings. + payload = spy.calls[0] + assert payload["openrouter/anthropic/claude-3-5-sonnet"][ + "input_cost_per_token" + ] == pytest.approx(3e-6) + assert payload["openrouter/anthropic/claude-3-5-sonnet"][ + "output_cost_per_token" + ] == pytest.approx(15e-6) + assert ( + payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"] + == "openrouter" + ) + + +def test_yaml_override_registers_under_alias_set(monkeypatch): + """Operator-declared ``input_cost_per_token`` / + ``output_cost_per_token`` on a YAML config registers under every + alias the YAML alias generator produces — including the ``azure/`` + normalisation for ``azure_openai`` providers. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5.4", + "litellm_params": { + "base_model": "gpt-5.4", + "input_cost_per_token": 2e-6, + "output_cost_per_token": 8e-6, + }, + } + ], + ) + + register_pricing_from_global_configs() + + keys = spy.all_keys + assert "gpt-5.4" in keys + assert "azure_openai/gpt-5.4" in keys + assert "azure/gpt-5.4" in keys + + payload = spy.calls[0] + entry = payload["gpt-5.4"] + assert entry["input_cost_per_token"] == pytest.approx(2e-6) + assert entry["output_cost_per_token"] == pytest.approx(8e-6) + assert entry["litellm_provider"] == "azure" + + +def test_no_override_means_no_registration(monkeypatch): + """A YAML config that *omits* both pricing fields must NOT be registered + — registering as zero would override LiteLLM's native pricing for the + ``base_model`` key (e.g. ``gpt-4o``) and silently make every user's + bill drop to $0. Fail-safe is "skip and warn", not "register zero". + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENAI", + "model_name": "gpt-4o", + "litellm_params": {"base_model": "gpt-4o"}, + } + ], + ) + + register_pricing_from_global_configs() + + assert spy.calls == [] + + +def test_openrouter_skipped_when_pricing_missing(monkeypatch): + """If the OpenRouter raw-pricing cache doesn't carry an entry for a + configured model (network blip during refresh, model added later, etc.), + we skip it rather than registering zero pricing. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}} + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + } + ], + ) + + register_pricing_from_global_configs() + + assert spy.calls == [] + + +def test_register_continues_after_individual_failure(monkeypatch, caplog): + """A single bad ``register_model`` call (e.g. raising LiteLLM error) + must not abort registration of the remaining configs. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"} + successful_calls: list[dict[str, Any]] = [] + + def _maybe_fail(payload: dict[str, Any]) -> None: + if any(k in failing_keys for k in payload): + raise RuntimeError("boom") + successful_calls.append(payload) + + monkeypatch.setattr( + "app.services.pricing_registration.litellm.register_model", + _maybe_fail, + raising=False, + ) + _patch_openrouter_pricing( + monkeypatch, + { + "anthropic/claude-3-5-sonnet": { + "prompt": "0.000003", + "completion": "0.000015", + } + }, + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + }, + { + "id": 2, + "provider": "OPENAI", + "model_name": "custom-deployment", + "litellm_params": { + "base_model": "custom-deployment", + "input_cost_per_token": 1e-6, + "output_cost_per_token": 2e-6, + }, + }, + ], + ) + + register_pricing_from_global_configs() + + # The good config still registered. + assert any("custom-deployment" in payload for payload in successful_calls) + + +def test_vision_configs_registered_with_chat_shape(monkeypatch): + """``register_pricing_from_global_configs`` walks + ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision + calls (during indexing) bill correctly. Vision configs use the same + chat-shape token prices, but image-gen pricing is intentionally NOT + registered here (handled via ``response_cost`` in LiteLLM). + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, + {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}}, + ) + + # No chat configs — only vision. Proves the vision walk is a separate + # iteration, not piggy-backed on the chat list. + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "billing_tier": "premium", + "input_cost_per_token": 5e-6, + "output_cost_per_token": 15e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/openai/gpt-4o" in spy.all_keys + payload_value = spy.calls[0]["openrouter/openai/gpt-4o"] + assert payload_value["mode"] == "chat" + assert payload_value["litellm_provider"] == "openrouter" + assert payload_value["input_cost_per_token"] == pytest.approx(5e-6) + assert payload_value["output_cost_per_token"] == pytest.approx(15e-6) + + +def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch): + """If the OpenRouter pricing cache misses a vision model (different + catalogue surface), the vision walk falls back to inline + ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash", + "billing_tier": "premium", + "input_cost_per_token": 1e-6, + "output_cost_per_token": 4e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/google/gemini-2.5-flash" in spy.all_keys diff --git a/surfsense_backend/tests/unit/services/test_provider_api_base.py b/surfsense_backend/tests/unit/services/test_provider_api_base.py new file mode 100644 index 000000000..12cd0a3d5 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_provider_api_base.py @@ -0,0 +1,107 @@ +"""Unit tests for the shared ``api_base`` resolver. + +The cascade exists so vision and image-gen call sites can't silently +inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``) +when an OpenRouter / Groq / etc. config ships an empty string. See +``provider_api_base`` module docstring for the original repro +(OpenRouter image-gen 404-ing against an Azure endpoint). +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_api_base import ( + PROVIDER_DEFAULT_API_BASE, + PROVIDER_KEY_DEFAULT_API_BASE, + resolve_api_base, +) + +pytestmark = pytest.mark.unit + + +def test_config_value_wins_over_defaults(): + """A non-empty config value is always returned verbatim, even when the + provider has a default — the operator gets the last word.""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base="https://my-openrouter-mirror.example.com/v1", + ) + assert result == "https://my-openrouter-mirror.example.com/v1" + + +def test_provider_key_default_when_config_missing(): + """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own + base URL — the provider-key map must take precedence over the prefix + map so DeepSeek requests don't go to OpenAI.""" + result = resolve_api_base( + provider="DEEPSEEK", + provider_prefix="openai", + config_api_base=None, + ) + assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_provider_prefix_default_when_no_key_default(): + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=None, + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_unknown_provider_returns_none(): + """When neither map matches we return ``None`` so the caller can let + LiteLLM apply its own provider-integration default (Azure deployment + URL, custom-provider URL, etc.).""" + result = resolve_api_base( + provider="SOMETHING_NEW", + provider_prefix="something_new", + config_api_base=None, + ) + assert result is None + + +def test_empty_string_config_treated_as_missing(): + """The original bug: OpenRouter dynamic configs ship ``api_base=""`` + and downstream call sites use ``if cfg.get("api_base"):`` — empty + strings are falsy in Python but the cascade has to step in anyway.""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base="", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_whitespace_only_config_treated_as_missing(): + """A config value of ``" "`` is a configuration mistake — treat it + as missing instead of forwarding whitespace to LiteLLM (which would + almost certainly 404).""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=" ", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_provider_case_insensitive(): + """Some call sites pass the provider lowercase (DB enum value), others + uppercase (YAML key). Both must resolve.""" + upper = resolve_api_base( + provider="DEEPSEEK", provider_prefix="openai", config_api_base=None + ) + lower = resolve_api_base( + provider="deepseek", provider_prefix="openai", config_api_base=None + ) + assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_all_inputs_none_returns_none(): + assert ( + resolve_api_base(provider=None, provider_prefix=None, config_api_base=None) + is None + ) diff --git a/surfsense_backend/tests/unit/services/test_provider_capabilities.py b/surfsense_backend/tests/unit/services/test_provider_capabilities.py new file mode 100644 index 000000000..aac88977f --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_provider_capabilities.py @@ -0,0 +1,244 @@ +"""Unit tests for the shared chat-image capability resolver. + +Two resolvers, two intents: + +- ``derive_supports_image_input`` — best-effort True for the catalog and + selector. Default-allow on unknown / unmapped models. The streaming + task safety net never sees this value directly. + +- ``is_known_text_only_chat_model`` — strict opt-out for the safety net. + Returns True only when LiteLLM's model map *explicitly* sets + ``supports_vision=False``. Anything else (missing key, exception, + True) returns False so the request flows through to the provider. +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_capabilities import ( + derive_supports_image_input, + is_known_text_only_chat_model, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# derive_supports_image_input — OpenRouter modalities path (authoritative) +# --------------------------------------------------------------------------- + + +def test_or_modalities_with_image_returns_true(): + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="openai/gpt-4o", + openrouter_input_modalities=["text", "image"], + ) + is True + ) + + +def test_or_modalities_text_only_returns_false(): + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="deepseek/deepseek-v3.2-exp", + openrouter_input_modalities=["text"], + ) + is False + ) + + +def test_or_modalities_empty_list_returns_false(): + """OR explicitly publishing an empty modality list is a definitive + 'no inputs at all' signal — treat as False rather than falling back + to LiteLLM.""" + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="weird/empty-modalities", + openrouter_input_modalities=[], + ) + is False + ) + + +def test_or_modalities_none_falls_through_to_litellm(): + """``None`` (missing key) is *not* a definitive signal — fall through + to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map.""" + assert ( + derive_supports_image_input( + provider="OPENAI", + model_name="gpt-4o", + openrouter_input_modalities=None, + ) + is True + ) + + +# --------------------------------------------------------------------------- +# derive_supports_image_input — LiteLLM model-map path +# --------------------------------------------------------------------------- + + +def test_litellm_known_vision_model_returns_true(): + assert ( + derive_supports_image_input( + provider="OPENAI", + model_name="gpt-4o", + ) + is True + ) + + +def test_litellm_base_model_wins_over_model_name(): + """Azure-style entries pass model_name=deployment_id and put the + canonical sku in litellm_params.base_model. The resolver must + consult base_model first or the deployment id (which LiteLLM + doesn't know) would shadow the real capability.""" + assert ( + derive_supports_image_input( + provider="AZURE_OPENAI", + model_name="my-azure-deployment-id", + base_model="gpt-4o", + ) + is True + ) + + +def test_litellm_unknown_model_default_allows(): + """Default-allow on unknown — the safety net is the actual block.""" + assert ( + derive_supports_image_input( + provider="CUSTOM", + model_name="brand-new-model-x9-unmapped", + custom_provider="brand_new_proxy", + ) + is True + ) + + +def test_litellm_known_text_only_returns_false(): + """A model that LiteLLM explicitly knows is text-only resolves to + False even via the catalog resolver. ``deepseek-chat`` (the + DeepSeek-V3 chat sku) is in the map without supports_vision and + LiteLLM's `supports_vision` returns False.""" + # Sanity: confirm the helper's negative path. We use a small model + # known not to support vision per the map. + result = derive_supports_image_input( + provider="DEEPSEEK", + model_name="deepseek-chat", + ) + # We accept either False (LiteLLM said explicit no) or True + # (default-allow if the entry isn't mapped on this version) — the + # invariant is that the resolver never *raises* on a known-text-only + # provider/model. The behaviour-binding assertion lives in + # ``test_is_known_text_only_chat_model_explicit_false`` below. + assert isinstance(result, bool) + + +# --------------------------------------------------------------------------- +# is_known_text_only_chat_model — strict opt-out semantics +# --------------------------------------------------------------------------- + + +def test_is_known_text_only_returns_false_for_vision_model(): + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_is_known_text_only_returns_false_for_unknown_model(): + """Strict opt-out: missing from the map ≠ text-only. The safety net + must NOT fire for an unmapped model — that's the regression we're + fixing.""" + assert ( + is_known_text_only_chat_model( + provider="CUSTOM", + model_name="brand-new-model-x9-unmapped", + custom_provider="brand_new_proxy", + ) + is False + ) + + +def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch): + """LiteLLM's ``get_model_info`` raises freely on parse errors. The + helper swallows the exception and returns False so the safety net + doesn't fire on a transient lookup failure.""" + import app.services.provider_capabilities as pc + + def _raise(**_kwargs): + raise ValueError("intentional test failure") + + monkeypatch.setattr(pc.litellm, "get_model_info", _raise) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch): + """Stub LiteLLM's ``get_model_info`` to return an explicit False so + we exercise the opt-out path deterministically. Using a stub keeps + the test stable across LiteLLM map updates.""" + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"supports_vision": False, "max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is True + ) + + +def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch): + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"supports_vision": True} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is False + ) + + +def test_is_known_text_only_returns_false_on_missing_key(monkeypatch): + """A model entry without ``supports_vision`` at all is treated as + 'unknown' — strict opt-out means False.""" + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"max_input_tokens": 8192} # no supports_vision + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is False + ) diff --git a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py new file mode 100644 index 000000000..9e35b6f9c --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py @@ -0,0 +1,157 @@ +"""Unit tests for ``QuotaCheckedVisionLLM``. + +Validates that: + +* Calling ``ainvoke`` routes through ``billable_call`` (premium credit + enforcement) and forwards the inner LLM's response on success. +* The wrapper proxies non-overridden attributes to the inner LLM + (``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output`` + still work without quota gating (they're not used in indexing today). +* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper + bubbles it up — the ETL pipeline catches that and falls back to OCR. +""" + +from __future__ import annotations + +import contextlib +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +class _FakeInnerLLM: + """Stand-in for ``langchain_litellm.ChatLiteLLM``.""" + + def __init__(self, response: Any = "OCR'd content") -> None: + self._response = response + self.ainvoke_calls: list[Any] = [] + + async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any: + self.ainvoke_calls.append(input) + return self._response + + def some_other_method(self, x: int) -> int: + return x * 2 + + +@contextlib.asynccontextmanager +async def _passthrough_billable_call(**_kwargs): + """Stand-in for billable_call that always allows the call to run.""" + + class _Acc: + total_cost_micros = 0 + total_prompt_tokens = 0 + total_completion_tokens = 0 + grand_total = 0 + calls: list[Any] = [] + + def per_message_summary(self) -> dict[str, dict[str, int]]: + return {} + + yield _Acc() + + +@pytest.mark.asyncio +async def test_ainvoke_routes_through_billable_call(monkeypatch): + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + captured_kwargs: list[dict[str, Any]] = [] + + @contextlib.asynccontextmanager + async def _spy_billable_call(**kwargs): + captured_kwargs.append(kwargs) + async with _passthrough_billable_call() as acc: + yield acc + + monkeypatch.setattr( + "app.services.quota_checked_vision_llm.billable_call", + _spy_billable_call, + raising=False, + ) + + inner = _FakeInnerLLM(response="A red apple on a white table") + user_id = uuid4() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=user_id, + search_space_id=99, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + result = await wrapper.ainvoke([{"text": "what is this?"}]) + assert result == "A red apple on a white table" + assert len(inner.ainvoke_calls) == 1 + assert len(captured_kwargs) == 1 + bc_kwargs = captured_kwargs[0] + assert bc_kwargs["user_id"] == user_id + assert bc_kwargs["search_space_id"] == 99 + assert bc_kwargs["billing_tier"] == "premium" + assert bc_kwargs["base_model"] == "openai/gpt-4o" + assert bc_kwargs["quota_reserve_tokens"] == 4000 + assert bc_kwargs["usage_type"] == "vision_extraction" + + +@pytest.mark.asyncio +async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch): + from app.services.billable_calls import QuotaInsufficientError + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + @contextlib.asynccontextmanager + async def _denying_billable_call(**_kwargs): + raise QuotaInsufficientError( + usage_type="vision_extraction", + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield # unreachable but required for asynccontextmanager type + + monkeypatch.setattr( + "app.services.quota_checked_vision_llm.billable_call", + _denying_billable_call, + raising=False, + ) + + inner = _FakeInnerLLM() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=uuid4(), + search_space_id=1, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + with pytest.raises(QuotaInsufficientError): + await wrapper.ainvoke([{"text": "x"}]) + + # Inner LLM never ran on a denied reservation. + assert inner.ainvoke_calls == [] + + +@pytest.mark.asyncio +async def test_proxies_non_overridden_attributes_to_inner(): + """``__getattr__`` forwards anything not on the proxy itself, so any + method we didn't explicitly override (``invoke``, ``astream``, + ``with_structured_output``, etc.) still works — just without quota + gating, which is fine because the indexer only ever calls ainvoke. + """ + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + inner = _FakeInnerLLM() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=uuid4(), + search_space_id=1, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + # ``some_other_method`` is on the inner only. + assert wrapper.some_other_method(7) == 14 diff --git a/surfsense_backend/tests/unit/services/test_supports_image_input.py b/surfsense_backend/tests/unit/services/test_supports_image_input.py new file mode 100644 index 000000000..71fdee1c7 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_supports_image_input.py @@ -0,0 +1,281 @@ +"""Unit tests for the chat-catalog ``supports_image_input`` capability flag. + +Capability is sourced from two places, in order of preference: + +1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs + (authoritative — OpenRouter publishes per-model modalities directly). +2. LiteLLM's authoritative model map (``litellm.supports_vision``) for + YAML / BYOK configs that don't carry an explicit operator override. + +The catalog default is *True* (conservative-allow): an unknown / unmapped +model is not pre-judged. The streaming-task safety net +(``is_known_text_only_chat_model``) is the only place a False actually +blocks a request — and it requires LiteLLM to *explicitly* mark the model +as text-only. +""" + +from __future__ import annotations + +import pytest + +from app.services.openrouter_integration_service import ( + _OPENROUTER_DYNAMIC_MARKER, + _generate_configs, + _supports_image_input, +) + +pytestmark = pytest.mark.unit + + +_SETTINGS_BASE: dict = { + "api_key": "sk-or-test", + "id_offset": -10_000, + "rpm": 200, + "tpm": 1_000_000, + "free_rpm": 20, + "free_tpm": 100_000, + "anonymous_enabled_paid": False, + "anonymous_enabled_free": True, + "quota_reserve_tokens": 4000, +} + + +# --------------------------------------------------------------------------- +# _supports_image_input helper (OpenRouter modalities) +# --------------------------------------------------------------------------- + + +def test_supports_image_input_true_for_multimodal(): + assert ( + _supports_image_input( + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + } + ) + is True + ) + + +def test_supports_image_input_false_for_text_only(): + """The exact failure mode the safety net guards against — DeepSeek V3 + is a text-in/text-out model and would 404 if forwarded image_url.""" + assert ( + _supports_image_input( + { + "id": "deepseek/deepseek-v3.2-exp", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + } + ) + is False + ) + + +def test_supports_image_input_false_when_modalities_missing(): + """Defensive: missing architecture is treated as text-only at the + OpenRouter helper level. The wider catalog resolver + (`derive_supports_image_input`) only consults modalities when they + are non-empty, otherwise it falls back to LiteLLM.""" + assert _supports_image_input({"id": "weird/model"}) is False + assert _supports_image_input({"id": "weird/model", "architecture": {}}) is False + assert ( + _supports_image_input( + {"id": "weird/model", "architecture": {"input_modalities": None}} + ) + is False + ) + + +# --------------------------------------------------------------------------- +# _generate_configs threads the flag onto every emitted chat config +# --------------------------------------------------------------------------- + + +def test_generate_configs_emits_supports_image_input(): + raw = [ + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + { + "id": "deepseek/deepseek-v3.2-exp", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.000003", "completion": "0.000015"}, + }, + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + by_model = {c["model_name"]: c for c in cfgs} + + gpt = by_model["openai/gpt-4o"] + assert gpt["supports_image_input"] is True + assert gpt[_OPENROUTER_DYNAMIC_MARKER] is True + + deepseek = by_model["deepseek/deepseek-v3.2-exp"] + assert deepseek["supports_image_input"] is False + assert deepseek[_OPENROUTER_DYNAMIC_MARKER] is True + + +# --------------------------------------------------------------------------- +# YAML loader: defer to derive_supports_image_input on unannotated entries +# --------------------------------------------------------------------------- + + +def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch): + """The regression case: an Azure GPT-5.x YAML entry without a + ``supports_image_input`` override should resolve to True via LiteLLM's + model map (which says ``supports_vision: true``). Previously this + defaulted to False, blocking every image turn for vision-capable + YAML configs.""" + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -2 + name: Azure GPT-4o + provider: AZURE_OPENAI + model_name: gpt-4o + api_key: sk-test +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + assert configs[0]["supports_image_input"] is True + + +def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch): + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -1 + name: GPT-4o + provider: OPENAI + model_name: gpt-4o + api_key: sk-test + supports_image_input: false +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + # Operator override always wins, even against LiteLLM's True. + assert configs[0]["supports_image_input"] is False + + +def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch): + """Unknown / unmapped model in YAML: default-allow. The streaming + safety net (which requires an explicit-False from LiteLLM) is the + only place a real block happens, so we don't lock the user out of + a freshly added third-party entry the catalog can't introspect.""" + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -1 + name: Some Brand New Model + provider: CUSTOM + custom_provider: brand_new_proxy + model_name: brand-new-model-x9 + api_key: sk-test +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + assert configs[0]["supports_image_input"] is True + + +# --------------------------------------------------------------------------- +# AgentConfig threads the flag through both YAML and Auto / BYOK +# --------------------------------------------------------------------------- + + +def test_agent_config_from_yaml_explicit_overrides_resolver(): + from app.agents.new_chat.llm_config import AgentConfig + + cfg_text_only = AgentConfig.from_yaml_config( + { + "id": -1, + "name": "Text Only Override", + "provider": "openai", + "model_name": "gpt-4o", # Capable per LiteLLM, but operator says no. + "api_key": "sk-test", + "supports_image_input": False, + } + ) + cfg_explicit_vision = AgentConfig.from_yaml_config( + { + "id": -2, + "name": "GPT-4o", + "provider": "openai", + "model_name": "gpt-4o", + "api_key": "sk-test", + "supports_image_input": True, + } + ) + assert cfg_text_only.supports_image_input is False + assert cfg_explicit_vision.supports_image_input is True + + +def test_agent_config_from_yaml_unannotated_uses_resolver(): + """Without an explicit YAML key, AgentConfig defers to the catalog + resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True.""" + from app.agents.new_chat.llm_config import AgentConfig + + cfg = AgentConfig.from_yaml_config( + { + "id": -1, + "name": "GPT-4o (no override)", + "provider": "openai", + "model_name": "gpt-4o", + "api_key": "sk-test", + } + ) + assert cfg.supports_image_input is True + + +def test_agent_config_auto_mode_supports_image_input(): + """Auto routes across the pool. We optimistically allow image input + so users can keep their selection on Auto with a vision-capable + deployment somewhere in the pool. The router's own `allowed_fails` + handles non-vision deployments via fallback.""" + from app.agents.new_chat.llm_config import AgentConfig + + auto = AgentConfig.from_auto_mode() + assert auto.supports_image_input is True diff --git a/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py new file mode 100644 index 000000000..63681828d --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py @@ -0,0 +1,515 @@ +"""Cost-based premium quota unit tests. + +Covers the USD-micro behaviour added in migration 140: + +* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all + calls in a turn — used as the debit amount when ``agent_config.is_premium`` + is true, regardless of which underlying model produced each call. This + preserves the prior "premium turn → all calls in turn count" rule from the + token-based system. +* ``estimate_call_reserve_micros`` scales linearly with model pricing, + clamps to a sane floor when pricing is unknown, and respects the + ``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry + can't lock the whole balance on one call. +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# TurnTokenAccumulator — premium-turn debit semantics +# --------------------------------------------------------------------------- + + +def test_total_cost_micros_sums_premium_and_free_calls(): + """A premium turn that also called a free sub-agent debits the union. + + The plan deliberately preserved the existing "premium turn → all calls + count" behaviour because per-call premium filtering relied on + ``LLMRouterService._premium_model_strings`` which only covers router-pool + deployments. ``total_cost_micros`` therefore must include free-model + calls (whose ``cost_micros`` is typically ``0``) as well as the premium + call's actual provider cost. + """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + # Premium model (e.g. claude-opus): non-zero cost. + acc.add( + model="anthropic/claude-3-5-sonnet", + prompt_tokens=1200, + completion_tokens=400, + total_tokens=1600, + cost_micros=12_345, + ) + # Free sub-agent (e.g. title-gen on a free model): zero cost. + acc.add( + model="gpt-4o-mini", + prompt_tokens=120, + completion_tokens=20, + total_tokens=140, + cost_micros=0, + ) + # A second premium-priced call within the same turn. + acc.add( + model="anthropic/claude-3-5-sonnet", + prompt_tokens=800, + completion_tokens=200, + total_tokens=1000, + cost_micros=7_500, + ) + + assert acc.total_cost_micros == 12_345 + 0 + 7_500 + # Token totals stay correct so the FE display path still works. + assert acc.grand_total == 1600 + 140 + 1000 + + +def test_total_cost_micros_zero_when_no_calls(): + """An empty accumulator must report zero cost (no division-by-zero, no None).""" + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + assert acc.total_cost_micros == 0 + assert acc.grand_total == 0 + + +def test_per_message_summary_groups_cost_by_model(): + """``per_message_summary`` must accumulate ``cost_micros`` per model so the + SSE ``model_breakdown`` payload reports actual USD spend per provider. + """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + acc.add( + model="claude-3-5-sonnet", + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + cost_micros=4_000, + ) + acc.add( + model="claude-3-5-sonnet", + prompt_tokens=200, + completion_tokens=100, + total_tokens=300, + cost_micros=8_000, + ) + acc.add( + model="gpt-4o-mini", + prompt_tokens=50, + completion_tokens=10, + total_tokens=60, + cost_micros=200, + ) + + summary = acc.per_message_summary() + assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000 + assert summary["claude-3-5-sonnet"]["total_tokens"] == 450 + assert summary["gpt-4o-mini"]["cost_micros"] == 200 + + +def test_serialized_calls_includes_cost_micros(): + """``serialized_calls`` is what flows into the SSE ``call_details`` + payload; cost_micros must be present on each entry so the FE message-info + dropdown can render per-call USD. + """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + acc.add( + model="m", + prompt_tokens=1, + completion_tokens=1, + total_tokens=2, + cost_micros=42, + ) + serialized = acc.serialized_calls() + assert serialized == [ + { + "model": "m", + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + "cost_micros": 42, + "call_kind": "chat", + } + ] + + +# --------------------------------------------------------------------------- +# estimate_call_reserve_micros — sizing and clamping +# --------------------------------------------------------------------------- + + +def test_reserve_returns_floor_when_model_unknown(monkeypatch): + """If LiteLLM doesn't know the model, ``get_model_info`` raises and the + helper falls back to the 100-micro floor — small enough that a user with + $0.0001 left can still send a tiny request, but non-zero so we still gate + against an empty balance. + """ + import litellm + + from app.services import token_quota_service + + def _raise(_name): + raise KeyError("unknown") + + monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="nonexistent-model", + quota_reserve_tokens=4000, + ) + assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS + assert micros == 100 + + +def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch): + """LiteLLM may *return* a model with both cost-per-token fields at 0 + (pricing not yet registered). The helper must not multiply 0 x tokens + and end up reserving 0 — it must clamp to the floor. + """ + import litellm + + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0}, + raising=False, + ) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="some-pending-model", + quota_reserve_tokens=4000, + ) + assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS + + +def test_reserve_scales_with_model_cost(monkeypatch): + """Claude-Opus-priced model with 4000 reserve_tokens reserves + ~$0.36 = 360_000 micros. Critically this must NOT be clamped down to + some small artificial cap — that was the bug the plan called out. + """ + import litellm + + from app.config import config + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: { + "input_cost_per_token": 15e-6, + "output_cost_per_token": 75e-6, + }, + raising=False, + ) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="claude-3-opus", + quota_reserve_tokens=4000, + ) + # 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros. + assert micros == 360_000 + + +def test_reserve_clamps_to_max_ceiling(monkeypatch): + """A misconfigured "$1000 / M" model with 4000 reserve_tokens would + nominally compute to $4 = 4_000_000 micros. The ceiling + ``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry + can't lock the user's whole balance on one call. + """ + import litellm + + from app.config import config + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: { + "input_cost_per_token": 1e-3, + "output_cost_per_token": 0, + }, + raising=False, + ) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="oops-misconfigured", + quota_reserve_tokens=4000, + ) + assert micros == 1_000_000 + + +def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch): + """Per-config ``quota_reserve_tokens`` is optional; when ``None`` or + zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL`` + so anonymous-style configs still reserve the operator-tunable default. + """ + import litellm + + from app.config import config + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: { + "input_cost_per_token": 1e-6, + "output_cost_per_token": 1e-6, + }, + raising=False, + ) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False) + + # 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros + assert ( + token_quota_service.estimate_call_reserve_micros( + base_model="cheap", quota_reserve_tokens=None + ) + == 4000 + ) + assert ( + token_quota_service.estimate_call_reserve_micros( + base_model="cheap", quota_reserve_tokens=0 + ) + == 4000 + ) + + +# --------------------------------------------------------------------------- +# TokenTrackingCallback — image vs chat usage shape +# --------------------------------------------------------------------------- + + +class _FakeImageUsage: + """Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape).""" + + def __init__( + self, + input_tokens: int = 0, + output_tokens: int = 0, + total_tokens: int | None = None, + ) -> None: + self.input_tokens = input_tokens + self.output_tokens = output_tokens + if total_tokens is not None: + self.total_tokens = total_tokens + + +class _FakeImageResponse: + """Mimics LiteLLM's ``ImageResponse`` — same name so the callback's + ``type(...).__name__`` probe routes to the image branch. + """ + + def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None): + self.usage = usage + if response_cost is not None: + self._hidden_params = {"response_cost": response_cost} + + +# Re-tag the helper class as ``ImageResponse`` for the type-name probe in +# the callback. We can't simply name the class ``ImageResponse`` because +# the test runner sometimes imports test modules in surprising ways and +# we want to be explicit. +_FakeImageResponse.__name__ = "ImageResponse" + + +class _FakeChatUsage: + def __init__(self, prompt: int, completion: int): + self.prompt_tokens = prompt + self.completion_tokens = completion + self.total_tokens = prompt + completion + + +class _FakeChatResponse: + def __init__(self, usage: _FakeChatUsage): + self.usage = usage + + +@pytest.mark.asyncio +async def test_callback_reads_image_usage_input_output_tokens(): + """``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens`` + for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT + prompt_tokens/completion_tokens which is the chat shape. + """ + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + cb = TokenTrackingCallback() + response = _FakeImageResponse( + usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50), + response_cost=0.04, # $0.04 per image + ) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04}, + response_obj=response, + start_time=None, + end_time=None, + ) + assert len(acc.calls) == 1 + call = acc.calls[0] + assert call.prompt_tokens == 42 + assert call.completion_tokens == 8 + assert call.total_tokens == 50 + # 0.04 USD = 40_000 micros + assert call.cost_micros == 40_000 + assert call.call_kind == "image_generation" + + +@pytest.mark.asyncio +async def test_callback_chat_path_unchanged(): + """Chat responses must still read prompt_tokens/completion_tokens.""" + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + cb = TokenTrackingCallback() + response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30)) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={ + "model": "openrouter/anthropic/claude-3-5-sonnet", + "response_cost": 0.0036, + }, + response_obj=response, + start_time=None, + end_time=None, + ) + assert len(acc.calls) == 1 + call = acc.calls[0] + assert call.prompt_tokens == 120 + assert call.completion_tokens == 30 + assert call.total_tokens == 150 + assert call.cost_micros == 3_600 + assert call.call_kind == "chat" + + +@pytest.mark.asyncio +async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch): + """When OpenRouter omits ``usage.cost`` LiteLLM's + ``default_image_cost_calculator`` raises. The defensive image branch in + ``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is + chat-shaped and would raise too) — it returns 0 with a WARNING log. + """ + import litellm + + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + # Force completion_cost to raise the same way OpenRouter image-gen fails. + def _boom(*_args, **_kwargs): + raise ValueError("model_cost: missing entry for openrouter image model") + + monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False) + + # And make sure cost_per_token is NEVER called for the image path — + # if it were, our ``is_image=True`` branch is broken. + cost_per_token_calls: list = [] + + def _record_cost_per_token(**kwargs): + cost_per_token_calls.append(kwargs) + return (0.0, 0.0) + + monkeypatch.setattr( + litellm, "cost_per_token", _record_cost_per_token, raising=False + ) + + cb = TokenTrackingCallback() + response = _FakeImageResponse( + usage=_FakeImageUsage(input_tokens=7, output_tokens=0) + ) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={"model": "openrouter/google/gemini-2.5-flash-image"}, + response_obj=response, + start_time=None, + end_time=None, + ) + + assert len(acc.calls) == 1 + assert acc.calls[0].cost_micros == 0 + assert acc.calls[0].call_kind == "image_generation" + # The image branch must short-circuit before cost_per_token. + assert cost_per_token_calls == [] + + +# --------------------------------------------------------------------------- +# scoped_turn — ContextVar reset semantics (issue B) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_scoped_turn_restores_outer_accumulator(): + """``scoped_turn`` must restore the previous ContextVar value on exit + so a per-call wrapper inside an outer chat turn doesn't leak its + accumulator outward (which would cause double-debit at chat-turn exit). + """ + from app.services.token_tracking_service import ( + get_current_accumulator, + scoped_turn, + start_turn, + ) + + outer = start_turn() + assert get_current_accumulator() is outer + + async with scoped_turn() as inner: + assert get_current_accumulator() is inner + assert inner is not outer + inner.add( + model="x", + prompt_tokens=1, + completion_tokens=1, + total_tokens=2, + cost_micros=5, + ) + + # After exit the outer accumulator is restored unchanged. + assert get_current_accumulator() is outer + assert outer.total_cost_micros == 0 + assert len(outer.calls) == 0 + # The inner accumulator captured the call but didn't bleed into outer. + assert inner.total_cost_micros == 5 + + +@pytest.mark.asyncio +async def test_scoped_turn_resets_to_none_when_no_outer(): + """Running ``scoped_turn`` outside any chat turn (e.g. a background + indexing job) must leave the ContextVar at ``None`` on exit so the + next *unrelated* request starts clean. + """ + from app.services.token_tracking_service import ( + _turn_accumulator, + get_current_accumulator, + scoped_turn, + ) + + # ContextVar default is None for a fresh test isolated context. We + # simulate "no outer" explicitly to be robust against test order. + token = _turn_accumulator.set(None) + try: + assert get_current_accumulator() is None + async with scoped_turn() as acc: + assert get_current_accumulator() is acc + assert get_current_accumulator() is None + finally: + _turn_accumulator.reset(token) diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py new file mode 100644 index 000000000..b8ba9d80c --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py @@ -0,0 +1,89 @@ +"""Defense-in-depth: vision-LLM resolution must not leak ``api_base`` +defaults from ``litellm.api_base`` either. + +Vision shares the same shape as image-gen — global YAML / OpenRouter +dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm`` +call sites would silently drop the empty string and inherit +``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on +construction so we test the kwargs we hand to it instead. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_get_vision_llm_global_openrouter_sets_api_base(): + """Global negative-ID branch: an OpenRouter vision config with + ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with + ``api_base="https://openrouter.ai/api/v1"`` — never an empty string, + never silently absent.""" + from app.services import llm_service + + cfg = { + "id": -30_001, + "name": "GPT-4o Vision (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "", + "api_version": None, + "litellm_params": {}, + "billing_tier": "free", + } + + search_space = MagicMock() + search_space.id = 1 + search_space.user_id = "user-x" + search_space.vision_llm_config_id = cfg["id"] + + session = AsyncMock() + scalars = MagicMock() + scalars.first.return_value = search_space + result = MagicMock() + result.scalars.return_value = scalars + session.execute.return_value = result + + captured: dict = {} + + class FakeSanitized: + def __init__(self, **kwargs): + captured.update(kwargs) + + with ( + patch( + "app.services.vision_llm_router_service.get_global_vision_llm_config", + return_value=cfg, + ), + patch( + "app.agents.new_chat.llm_config.SanitizedChatLiteLLM", + new=FakeSanitized, + ), + ): + await llm_service.get_vision_llm(session=session, search_space_id=1) + + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-4o" + + +def test_vision_router_deployment_sets_api_base_when_config_empty(): + """Auto-mode vision router: deployments are fed to ``litellm.Router``, + so the resolver has to apply at deployment construction time too.""" + from app.services.vision_llm_router_service import VisionLLMRouterService + + deployment = VisionLLMRouterService._config_to_deployment( + { + "model_name": "openai/gpt-4o", + "provider": "OPENROUTER", + "api_key": "sk-or-test", + "api_base": "", + } + ) + assert deployment is not None + assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" + assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o" diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py index 9258d5cfe..60750396c 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py @@ -51,22 +51,34 @@ class _FakeToolMessage: tool_call_id: str | None = None +@dataclass +class _FakeInterrupt: + value: dict[str, Any] + + +@dataclass +class _FakeTask: + interrupts: tuple[_FakeInterrupt, ...] = () + + class _FakeAgentState: """Stand-in for ``StateSnapshot`` returned by ``aget_state``.""" - def __init__(self) -> None: + def __init__(self, tasks: list[Any] | None = None) -> None: # Empty values keeps the cloud-fallback safety-net branch a no-op, - # and an empty ``tasks`` list keeps the post-stream interrupt - # check a no-op too. + # and empty ``tasks`` keep the post-stream interrupt check a no-op too. self.values: dict[str, Any] = {} - self.tasks: list[Any] = [] + self.tasks: list[Any] = tasks or [] class _FakeAgent: """Replays a list of ``astream_events`` events.""" - def __init__(self, events: list[dict[str, Any]]) -> None: + def __init__( + self, events: list[dict[str, Any]], state: _FakeAgentState | None = None + ) -> None: self._events = events + self._state = state or _FakeAgentState() async def astream_events( # type: ignore[no-untyped-def] self, _input_data: Any, *, config: dict[str, Any], version: str @@ -79,7 +91,7 @@ class _FakeAgent: # Called once after astream_events drains so the cloud-fallback # safety net can inspect staged filesystem work. The fake stays # empty so the safety net is a no-op. - return _FakeAgentState() + return self._state def _model_stream( @@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None: ) -async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]: +async def _drain( + events: list[dict[str, Any]], state: _FakeAgentState | None = None +) -> list[dict[str, Any]]: """Run ``_stream_agent_events`` against a fake agent and return the SSE payloads (parsed JSON) it yielded. """ - agent = _FakeAgent(events) + agent = _FakeAgent(events, state=state) service = VercelStreamingService() result = StreamResult() config = {"configurable": {"thread_id": "test-thread"}} @@ -525,3 +539,31 @@ async def test_unmatched_fallback_still_attaches_lc_id( assert len(starts) == 1 assert starts[0]["toolCallId"].startswith("call_run-1") assert starts[0]["langchainToolCallId"] == "lc-orphan" + + +@pytest.mark.asyncio +async def test_interrupt_request_uses_task_that_contains_interrupt( + parity_v2_on: None, +) -> None: + interrupt_payload = { + "type": "calendar_event_create", + "action": { + "tool": "create_calendar_event", + "params": {"summary": "mom bday"}, + }, + "context": {}, + } + state = _FakeAgentState( + tasks=[ + _FakeTask(interrupts=()), + _FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)), + ] + ) + + payloads = await _drain([], state=state) + + interrupts = _of_type(payloads, "data-interrupt-request") + assert len(interrupts) == 1 + assert ( + interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event" + ) diff --git a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py new file mode 100644 index 000000000..a5bb3f58a --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py @@ -0,0 +1,318 @@ +"""Regression tests for ``run_async_celery_task``. + +These tests pin down the production bug observed on 2026-05-02 where +the video-presentation Celery task hung at ``[billable_call] finalize`` +because the shared ``app.db.engine`` had pooled asyncpg connections +bound to a *previous* task's now-closed event loop. Reusing such a +connection on a fresh loop crashes inside ``pool_pre_ping`` with:: + + AttributeError: 'NoneType' object has no attribute 'send' + +(the proactor is None because the loop is gone) and can hang forever +inside the asyncpg ``Connection._cancel`` cleanup coroutine. + +The fix is ``run_async_celery_task``: a small helper that runs every +async celery task body inside a fresh event loop and disposes the +shared engine pool both before (defends against a previous task that +crashed) and after (releases connections we opened on this loop). + +Tests here exercise the helper with a stub engine that records +``dispose()`` calls and panics if a coroutine produced by one loop is +awaited on another — mirroring the real asyncpg behaviour. +""" + +from __future__ import annotations + +import asyncio +import gc +import sys +from collections.abc import Iterator +from contextlib import contextmanager +from unittest.mock import patch + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Stub engine that emulates the asyncpg-on-stale-loop crash +# --------------------------------------------------------------------------- + + +class _StaleLoopEngine: + """Tiny stand-in for ``app.db.engine`` that tracks dispose() calls. + + ``dispose()`` is async (matches ``AsyncEngine.dispose``) and records + the running event loop id so tests can assert it ran on *each* + fresh loop. + """ + + def __init__(self) -> None: + self.dispose_loop_ids: list[int] = [] + + async def dispose(self) -> None: + loop = asyncio.get_running_loop() + self.dispose_loop_ids.append(id(loop)) + + +@contextmanager +def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]: + """Patch ``from app.db import engine as shared_engine`` lookup. + + The helper imports lazily inside the function body, so we have to + patch the attribute on the already-loaded ``app.db`` module. + """ + import app.db as app_db + + original = getattr(app_db, "engine", None) + app_db.engine = stub # type: ignore[attr-defined] + try: + yield + finally: + if original is None: + with pytest.raises(AttributeError): + _ = app_db.engine + else: + app_db.engine = original # type: ignore[attr-defined] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_runner_returns_value_and_disposes_engine_around_call() -> None: + """Happy path: the coroutine result is returned, and the shared + engine is disposed both before and after the task body runs. + """ + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + async def _body() -> str: + # Engine should already have been disposed once before we run. + assert len(stub.dispose_loop_ids) == 1 + return "ok" + + with _patch_shared_engine(stub): + result = run_async_celery_task(_body) + + assert result == "ok" + # Once before the body, once after (in finally). + assert len(stub.dispose_loop_ids) == 2 + # Both disposes ran on the SAME (fresh) loop the task body used. + assert stub.dispose_loop_ids[0] == stub.dispose_loop_ids[1] + + +def test_runner_creates_fresh_loop_per_invocation() -> None: + """Each call must spin its own loop. Without this guarantee a + previous task's loop would be reused and the asyncpg-stale-loop + crash would never be avoided. + """ + import app.tasks.celery_tasks as celery_tasks_pkg + + stub = _StaleLoopEngine() + new_loop_calls = 0 + closed_loops: list[bool] = [] + + real_new_event_loop = asyncio.new_event_loop + + def _counting_new_loop() -> asyncio.AbstractEventLoop: + nonlocal new_loop_calls + new_loop_calls += 1 + loop = real_new_event_loop() + # Hook close() so we can verify each loop was closed properly + # before the next one was created. + original_close = loop.close + + def _tracked_close() -> None: + closed_loops.append(True) + original_close() + + loop.close = _tracked_close # type: ignore[method-assign] + return loop + + async def _body() -> None: + # Loop is alive and current at body execution time. + running = asyncio.get_running_loop() + assert not running.is_closed() + + with ( + _patch_shared_engine(stub), + patch.object(asyncio, "new_event_loop", _counting_new_loop), + ): + for _ in range(3): + celery_tasks_pkg.run_async_celery_task(_body) + + assert new_loop_calls == 3 + assert closed_loops == [True, True, True] + # Each invocation disposed twice (before + after). + assert len(stub.dispose_loop_ids) == 6 + + +def test_runner_disposes_engine_even_when_body_raises() -> None: + """Cleanup MUST run on the failure path too — otherwise stale + connections leak into the next task and cause the original hang. + """ + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + class _BoomError(RuntimeError): + pass + + async def _body() -> None: + raise _BoomError("kaboom") + + with _patch_shared_engine(stub), pytest.raises(_BoomError): + run_async_celery_task(_body) + + assert len(stub.dispose_loop_ids) == 2 # before + after still ran + + +def test_runner_swallows_dispose_errors() -> None: + """A flaky engine.dispose() must NEVER take down a celery task. + + Production scenario: the very first dispose (before the body runs) + might hit a partially-initialised engine; the helper logs and + moves on. The task body still runs; the result is still returned. + """ + from app.tasks.celery_tasks import run_async_celery_task + + class _AngryEngine: + def __init__(self) -> None: + self.calls = 0 + + async def dispose(self) -> None: + self.calls += 1 + raise RuntimeError("dispose() blew up") + + stub = _AngryEngine() + + async def _body() -> int: + return 42 + + with _patch_shared_engine(stub): + assert run_async_celery_task(_body) == 42 + + assert stub.calls == 2 # before + after both attempted + + +def test_runner_propagates_value_from_async_body() -> None: + """Sanity: pass-through of any pickleable celery return value.""" + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + async def _body() -> dict[str, object]: + return {"status": "ready", "video_presentation_id": 19} + + with _patch_shared_engine(stub): + out = run_async_celery_task(_body) + + assert out == {"status": "ready", "video_presentation_id": 19} + + +def test_video_presentation_task_uses_runner_helper() -> None: + """Defence-in-depth: confirm the celery task module imports + ``run_async_celery_task``. If a future refactor inlines a + ``loop = asyncio.new_event_loop(); ... loop.close()`` block again, + the original hang will return. + """ + # The module's task body should not contain a manual new_event_loop + # call — that's exactly what the helper exists to centralise. + import inspect + + from app.tasks.celery_tasks import video_presentation_tasks + + src = inspect.getsource(video_presentation_tasks) + assert "run_async_celery_task" in src, ( + "video_presentation_tasks.py must use run_async_celery_task; " + "manual asyncio.new_event_loop() in a celery task hangs on the " + "shared SQLAlchemy pool when reused across tasks." + ) + assert "asyncio.new_event_loop" not in src, ( + "video_presentation_tasks.py contains a raw asyncio.new_event_loop " + "call — route every async task through run_async_celery_task to " + "avoid the stale-pool hang." + ) + + +def test_podcast_task_uses_runner_helper() -> None: + """Symmetric assertion for the podcast task — same root cause, same + fix, same regression risk. + """ + import inspect + + from app.tasks.celery_tasks import podcast_tasks + + src = inspect.getsource(podcast_tasks) + assert "run_async_celery_task" in src + assert "asyncio.new_event_loop" not in src + + +def test_runner_runs_shutdown_asyncgens_before_close() -> None: + """If the task body created any async generators that didn't get + fully iterated, we must still call ``loop.shutdown_asyncgens()`` + before closing — otherwise we leak event-loop bound resources + that re-emerge as ``RuntimeError: Event loop is closed`` later. + """ + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + async def _agen(): + try: + yield 1 + yield 2 + finally: + pass + + async def _body() -> None: + # Iterate the agen partially, then leave it dangling — exactly + # the situation shutdown_asyncgens() is designed to clean up. + async for v in _agen(): + if v == 1: + break + + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + # By the time the helper returns, garbage collection + shutdown_asyncgens + # should have ensured no live async-gen references remain. We don't + # assert agen.closed directly (it depends on GC ordering); the real + # contract is "no warnings, no event-loop-closed errors". A successful + # second invocation proves the loop was cleaned up properly. + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + # Force a GC pass to surface any 'coroutine was never awaited' + # warnings that would indicate the cleanup is broken. + gc.collect() + + +def test_runner_uses_proactor_loop_on_windows() -> None: + """On Windows the celery worker preselects a Proactor policy so + subprocess (ffmpeg) calls work. The helper must not silently fall + back to a Selector loop and re-break video/podcast generation. + """ + if not sys.platform.startswith("win"): + pytest.skip("Windows-specific event-loop policy assertion") + + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + # Mirror the policy set at the top of every Windows celery task. + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + + observed: list[str] = [] + + async def _body() -> None: + observed.append(type(asyncio.get_running_loop()).__name__) + + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + assert observed == ["ProactorEventLoop"] diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py new file mode 100644 index 000000000..699297df1 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py @@ -0,0 +1,388 @@ +"""Unit tests for podcast Celery task billing integration. + +Validates ``_generate_content_podcast`` correctly wraps +``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the +search-space owner's billing decision, and degrades cleanly when the +resolver fails or premium credit is exhausted. + +Coverage: + +* Happy-path free config: resolver → ``billable_call`` enters with + ``usage_type='podcast_generation'`` and the configured reserve override, + graph runs, podcast row flips to ``READY``. +* Happy-path premium config: same wiring with ``billing_tier='premium'``. +* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` → + graph is *not* invoked, podcast row flips to ``FAILED``, return dict + carries ``reason='premium_quota_exhausted'``. +* Resolver failure: ``ValueError`` from the resolver → podcast row flips + to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``. +""" + +from __future__ import annotations + +import contextlib +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + def filter(self, *_args, **_kwargs): + return self + + +class _FakeSession: + def __init__(self, podcast): + self._podcast = podcast + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self._podcast) + + async def commit(self): + self.commit_count += 1 + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return None + + +class _FakeSessionMaker: + def __init__(self, session: _FakeSession): + self._session = session + + def __call__(self): + return self._session + + +def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace: + """Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily + inside helpers keeps this fixture cheap.""" + return SimpleNamespace( + id=podcast_id, + title="Test Podcast", + thread_id=thread_id, + status=None, + podcast_transcript=None, + file_location=None, + ) + + +@contextlib.asynccontextmanager +async def _ok_billable_call(**kwargs): + """Stand-in for ``billable_call`` that records its kwargs and yields a + no-op accumulator-shaped object.""" + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + + +_CALL_LOG: list[dict[str, Any]] = [] + + +@contextlib.asynccontextmanager +async def _denying_billable_call(**kwargs): + from app.services.billable_calls import QuotaInsufficientError + + _CALL_LOG.append(kwargs) + raise QuotaInsufficientError( + usage_type=kwargs.get("usage_type", "?"), + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield SimpleNamespace() # pragma: no cover — for grammar only + + +@contextlib.asynccontextmanager +async def _settlement_failing_billable_call(**kwargs): + from app.services.billable_calls import BillingSettlementError + + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + raise BillingSettlementError( + usage_type=kwargs.get("usage_type", "?"), + user_id=kwargs["user_id"], + cause=RuntimeError("finalize failed"), + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_call_log(): + _CALL_LOG.clear() + yield + _CALL_LOG.clear() + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch): + """Happy path: free billing tier still wraps the graph call so the + audit row is recorded. Verifies kwargs threading.""" + from app.config import config as app_config + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=7, thread_id=99) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + assert search_space_id == 555 + assert thread_id == 99 + return user_id, "free", "openrouter/some-free-model" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return { + "podcast_transcript": [ + SimpleNamespace(speaker_id=0, dialog="Hi"), + SimpleNamespace(speaker_id=1, dialog="Hello"), + ], + "final_podcast_file_path": "/tmp/podcast.wav", + } + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hello world", + search_space_id=555, + user_prompt="make it short", + ) + + assert result["status"] == "ready" + assert result["podcast_id"] == 7 + assert podcast.status == PodcastStatus.READY + assert podcast.file_location == "/tmp/podcast.wav" + + assert len(_CALL_LOG) == 1 + call = _CALL_LOG[0] + assert call["user_id"] == user_id + assert call["search_space_id"] == 555 + assert call["billing_tier"] == "free" + assert call["base_model"] == "openrouter/some-free-model" + assert call["usage_type"] == "podcast_generation" + assert ( + call["quota_reserve_micros_override"] + == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS + ) + # Background artifact audit rows intentionally omit the TokenUsage.thread_id + # FK to avoid coupling Celery audit commits to an active chat transaction. + assert "thread_id" not in call + assert call["call_details"] == { + "podcast_id": 7, + "title": "Test Podcast", + "thread_id": 99, + } + assert callable(call["billable_session_factory"]) + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_premium_tier(monkeypatch): + """Premium resolution flows through to ``billable_call`` so the + reserve/finalize path triggers.""" + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast() + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return user_id, "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert _CALL_LOG[0]["billing_tier"] == "premium" + assert _CALL_LOG[0]["base_model"] == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch): + """When ``billable_call`` denies the reservation, the graph never + runs and the podcast row flips to FAILED with the documented reason + code.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=8) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=8, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 8, + "reason": "premium_quota_exhausted", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] # Graph never ran on denied reservation. + + +@pytest.mark.asyncio +async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch): + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=10) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr( + podcast_tasks, "billable_call", _settlement_failing_billable_call + ) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=10, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 10, + "reason": "billing_settlement_failed", + } + assert podcast.status == PodcastStatus.FAILED + + +@pytest.mark.asyncio +async def test_resolver_failure_marks_podcast_failed(monkeypatch): + """If the resolver raises (e.g. search-space deleted), the task fails + cleanly without invoking the graph.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=9) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _failing_resolver(sess, search_space_id, *, thread_id=None): + raise ValueError("Search space 555 not found") + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=9, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 9, + "reason": "billing_resolution_failed", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] diff --git a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py new file mode 100644 index 000000000..792d059b0 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py @@ -0,0 +1,119 @@ +"""Predicate-level test for the chat streaming safety net. + +The safety net in ``stream_new_chat`` rejects an image turn early with +a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the +selected model is *known* to be text-only. The earlier round of this +work used a strict opt-in flag (``supports_image_input`` defaulting to +False on every YAML entry) which blocked vision-capable Azure GPT-5.x +deployments — this is the regression we're fixing. + +The new predicate is :func:`is_known_text_only_chat_model`, which +returns True only when LiteLLM's authoritative model map *explicitly* +sets ``supports_vision=False``. Anything else (vision True, missing +key, exception) returns False so the request flows through to the +provider. + +We exercise the predicate directly here rather than driving the full +``stream_new_chat`` generator — covering the gate in isolation keeps +the test focused on the regression while the generator's wider behavior +is exercised by the integration suite. +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_capabilities import is_known_text_only_chat_model + +pytestmark = pytest.mark.unit + + +def test_safety_net_does_not_fire_for_azure_gpt_4o(): + """Regression: ``azure/gpt-4o`` (and the GPT-5.x variants) is + vision-capable per LiteLLM's model map. The previous round's + blanket-False default blocked it; the new predicate must NOT mark + it text-only.""" + assert ( + is_known_text_only_chat_model( + provider="AZURE_OPENAI", + model_name="my-azure-deployment", + base_model="gpt-4o", + ) + is False + ) + + +def test_safety_net_does_not_fire_for_unknown_model(): + """Default-pass on unknown — the safety net only blocks definitive + text-only confirmations. A freshly added third-party model that + LiteLLM doesn't know about must flow through to the provider.""" + assert ( + is_known_text_only_chat_model( + provider="CUSTOM", + custom_provider="brand_new_proxy", + model_name="brand-new-model-x9", + ) + is False + ) + + +def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch): + """Transient ``litellm.get_model_info`` exception ≠ block. The + helper swallows the error and treats it as 'unknown' → False.""" + import app.services.provider_capabilities as pc + + def _raise(**_kwargs): + raise RuntimeError("intentional test failure") + + monkeypatch.setattr(pc.litellm, "get_model_info", _raise) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_safety_net_fires_only_on_explicit_false(monkeypatch): + """Stub LiteLLM to assert the only path that returns True is the + explicit ``supports_vision=False`` case. Anything else (True, + None, missing key) returns False from the predicate.""" + import app.services.provider_capabilities as pc + + def _info_explicit_false(**_kwargs): + return {"supports_vision": False, "max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="text-only-stub", + ) + is True + ) + + def _info_true(**_kwargs): + return {"supports_vision": True} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_true) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="vision-stub", + ) + is False + ) + + def _info_missing(**_kwargs): + return {"max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="missing-key-stub", + ) + is False + ) diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py new file mode 100644 index 000000000..423b64ddb --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py @@ -0,0 +1,398 @@ +"""Unit tests for video-presentation Celery task billing integration. + +Mirrors ``test_podcast_billing.py`` for the video-presentation task. +Validates the same wrap-graph-in-billable_call pattern and ensures the +larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is +threaded through. + +Coverage: + +* Free config: graph runs, ``billable_call`` invoked with the video + reserve override. +* Premium config: same wiring with ``billing_tier='premium'``. +* Quota denial: graph not invoked, row → FAILED, reason code surfaced. +* Resolver failure: row → FAILED with ``billing_resolution_failed``. +""" + +from __future__ import annotations + +import contextlib +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + def filter(self, *_args, **_kwargs): + return self + + +class _FakeSession: + def __init__(self, video): + self._video = video + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self._video) + + async def commit(self): + self.commit_count += 1 + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return None + + +class _FakeSessionMaker: + def __init__(self, session: _FakeSession): + self._session = session + + def __call__(self): + return self._session + + +def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace: + return SimpleNamespace( + id=video_id, + title="Test Presentation", + thread_id=thread_id, + status=None, + slides=None, + scene_codes=None, + ) + + +_CALL_LOG: list[dict[str, Any]] = [] + + +@contextlib.asynccontextmanager +async def _ok_billable_call(**kwargs): + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + + +@contextlib.asynccontextmanager +async def _denying_billable_call(**kwargs): + from app.services.billable_calls import QuotaInsufficientError + + _CALL_LOG.append(kwargs) + raise QuotaInsufficientError( + usage_type=kwargs.get("usage_type", "?"), + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield SimpleNamespace() # pragma: no cover + + +@contextlib.asynccontextmanager +async def _settlement_failing_billable_call(**kwargs): + from app.services.billable_calls import BillingSettlementError + + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + raise BillingSettlementError( + usage_type=kwargs.get("usage_type", "?"), + user_id=kwargs["user_id"], + cause=RuntimeError("finalize failed"), + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_call_log(): + _CALL_LOG.clear() + yield + _CALL_LOG.clear() + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch): + from app.config import config as app_config + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=11, thread_id=99) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + assert search_space_id == 777 + assert thread_id == 99 + return user_id, "free", "openrouter/some-free-model" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=11, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result["status"] == "ready" + assert result["video_presentation_id"] == 11 + assert video.status == VideoPresentationStatus.READY + + assert len(_CALL_LOG) == 1 + call = _CALL_LOG[0] + assert call["user_id"] == user_id + assert call["search_space_id"] == 777 + assert call["billing_tier"] == "free" + assert call["base_model"] == "openrouter/some-free-model" + assert call["usage_type"] == "video_presentation_generation" + assert ( + call["quota_reserve_micros_override"] + == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS + ) + # Background artifact audit rows intentionally omit the TokenUsage.thread_id + # FK to avoid coupling Celery audit commits to an active chat transaction. + assert "thread_id" not in call + assert call["call_details"] == { + "video_presentation_id": 11, + "title": "Test Presentation", + "thread_id": 99, + } + assert callable(call["billable_session_factory"]) + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_premium_tier(monkeypatch): + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video() + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return user_id, "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + await video_presentation_tasks._generate_video_presentation( + video_presentation_id=11, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert _CALL_LOG[0]["billing_tier"] == "premium" + assert _CALL_LOG[0]["base_model"] == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=12) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr( + video_presentation_tasks, "billable_call", _denying_billable_call + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=12, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 12, + "reason": "premium_quota_exhausted", + } + assert video.status == VideoPresentationStatus.FAILED + assert graph_invoked == [] + + +@pytest.mark.asyncio +async def test_billing_settlement_failure_marks_video_failed(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=14) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr( + video_presentation_tasks, + "billable_call", + _settlement_failing_billable_call, + ) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=14, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 14, + "reason": "billing_settlement_failed", + } + assert video.status == VideoPresentationStatus.FAILED + + +@pytest.mark.asyncio +async def test_resolver_failure_marks_video_failed(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=13) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _failing_resolver(sess, search_space_id, *, thread_id=None): + raise ValueError("Search space 777 not found") + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _failing_resolver, + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=13, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 13, + "reason": "billing_resolution_failed", + } + assert video.status == VideoPresentationStatus.FAILED + assert graph_invoked == [] diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index cc8157464..64e4d5157 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -271,6 +271,66 @@ async def test_preflight_skipped_for_auto_router_model(): await _preflight_llm(fake_llm) +@pytest.mark.asyncio +async def test_settle_speculative_agent_build_swallows_exceptions(): + """``_settle_speculative_agent_build`` MUST always return cleanly so the + caller can safely re-touch the request-scoped session afterwards. + + The helper guards the parallel preflight + agent-build path: when the + speculative build is being discarded (429 or non-429 preflight failure) + we await it solely to release any in-flight ``AsyncSession`` usage — + the build's outcome is irrelevant. Any exception (including + ``CancelledError``) leaking out would skip the caller's recovery flow + and re-introduce the very session-concurrency hazard the helper exists + to prevent. + """ + import asyncio + + from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build + + async def _raises() -> None: + raise RuntimeError("speculative build crashed") + + async def _succeeds() -> str: + return "agent" + + async def _slow() -> None: + await asyncio.sleep(0.05) + + for coro in (_raises(), _succeeds(), _slow()): + task = asyncio.create_task(coro) + await _settle_speculative_agent_build(task) + assert task.done() + + +@pytest.mark.asyncio +async def test_settle_speculative_agent_build_handles_already_done_task(): + """Done tasks (success or failure) must still be settled without raising.""" + import asyncio + + from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build + + async def _ok() -> str: + return "ok" + + async def _bad() -> None: + raise ValueError("nope") + + ok_task = asyncio.create_task(_ok()) + bad_task = asyncio.create_task(_bad()) + # Drive both to completion before settling. + await asyncio.sleep(0) + await asyncio.sleep(0) + + await _settle_speculative_agent_build(ok_task) + await _settle_speculative_agent_build(bad_task) + assert ok_task.result() == "ok" + # ``bad_task`` exception was consumed by the settle helper; calling + # ``.exception()`` after the fact must still return the original error + # (the helper observes it but doesn't clear it). + assert isinstance(bad_task.exception(), ValueError) + + def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index efe670d05..ffc977262 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -7947,7 +7947,7 @@ wheels = [ [[package]] name = "surf-new-backend" -version = "0.0.19" +version = "0.0.20" source = { editable = "." } dependencies = [ { name = "alembic" }, @@ -8045,7 +8045,7 @@ requires-dist = [ { name = "composio", specifier = ">=0.10.9" }, { name = "datasets", specifier = ">=2.21.0" }, { name = "daytona", specifier = ">=0.146.0" }, - { name = "deepagents", specifier = ">=0.4.12" }, + { name = "deepagents", specifier = ">=0.4.12,<0.5" }, { name = "discord-py", specifier = ">=2.5.2" }, { name = "docling", specifier = ">=2.15.0" }, { name = "elasticsearch", specifier = ">=9.1.1" }, diff --git a/surfsense_browser_extension/package.json b/surfsense_browser_extension/package.json index 146dd177e..1ffc4dd87 100644 --- a/surfsense_browser_extension/package.json +++ b/surfsense_browser_extension/package.json @@ -1,7 +1,7 @@ { "name": "surfsense_browser_extension", "displayName": "Surfsense Browser Extension", - "version": "0.0.19", + "version": "0.0.20", "description": "Extension to collect Browsing History for SurfSense.", "author": "https://github.com/MODSetter", "engines": { diff --git a/surfsense_desktop/package.json b/surfsense_desktop/package.json index e2712d8ea..960267e16 100644 --- a/surfsense_desktop/package.json +++ b/surfsense_desktop/package.json @@ -1,6 +1,6 @@ { "name": "surfsense-desktop", - "version": "0.0.19", + "version": "0.0.20", "description": "SurfSense Desktop App", "main": "dist/main.js", "scripts": { diff --git a/surfsense_web/app/(home)/free/page.tsx b/surfsense_web/app/(home)/free/page.tsx index 8d9ed5cb1..3ddd5195f 100644 --- a/surfsense_web/app/(home)/free/page.tsx +++ b/surfsense_web/app/(home)/free/page.tsx @@ -127,7 +127,7 @@ const FAQ_ITEMS = [ { question: "What happens after I use my free tokens?", answer: - "After your free tokens, create a free SurfSense account to unlock 3 million more premium tokens. Additional tokens can be purchased at $1 per million. Non-premium models remain unlimited for registered users.", + "After your free tokens, create a free SurfSense account to unlock $5 of premium credit. Additional credit can be topped up at $1 for $1 of credit, billed at the actual provider cost. Non-premium models remain unlimited for registered users.", }, { question: "Is Claude AI available without login?", @@ -329,7 +329,7 @@ export default async function FreeHubPage() {

Want More Features?

- Create a free SurfSense account to unlock 3 million tokens, document uploads with + Create a free SurfSense account to unlock $5 of premium credit, document uploads with citations, team collaboration, and integrations with Slack, Google Drive, Notion, and 30+ more tools.

diff --git a/surfsense_web/app/(home)/pricing/page.tsx b/surfsense_web/app/(home)/pricing/page.tsx index 6ad9435bf..2a413b9a9 100644 --- a/surfsense_web/app/(home)/pricing/page.tsx +++ b/surfsense_web/app/(home)/pricing/page.tsx @@ -5,7 +5,7 @@ import { BreadcrumbNav } from "@/components/seo/breadcrumb-nav"; export const metadata: Metadata = { title: "Pricing | SurfSense - Free AI Search Plans", description: - "Explore SurfSense plans and pricing. Start free with 500 pages & 3M premium tokens. Use ChatGPT, Claude AI, and premium AI models. Pay-as-you-go tokens at $1 per million.", + "Explore SurfSense plans and pricing. Start free with 500 pages & $5 in premium credits. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost.", alternates: { canonical: "https://surfsense.com/pricing", }, diff --git a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx index 3017160e1..0c5662712 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx @@ -8,7 +8,7 @@ import { cn } from "@/lib/utils"; const TABS = [ { id: "pages", label: "Pages" }, - { id: "tokens", label: "Premium Tokens" }, + { id: "tokens", label: "Premium Credit" }, ] as const; type TabId = (typeof TABS)[number]["id"]; diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 39201e5cc..4c8e4fe93 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -13,6 +13,7 @@ import { useParams } from "next/navigation"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { z } from "zod"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms"; import { clearTargetCommentIdAtom, @@ -393,6 +394,8 @@ export default function NewChatPage() { // Get current user for author info in shared chats const { data: currentUser } = useAtomValue(currentUserAtom); + const { data: agentFlags } = useAtomValue(agentFlagsAtom); + const localFilesystemEnabled = agentFlags?.enable_desktop_local_filesystem === true; // Live collaboration: sync session state and messages via Zero useChatSessionStateSync(threadId); @@ -989,7 +992,9 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const selection = await getAgentFilesystemSelection(searchSpaceId); + const selection = await getAgentFilesystemSelection(searchSpaceId, { + localFilesystemEnabled, + }); if ( selection.filesystem_mode === "desktop_local_folder" && (!selection.local_filesystem_mounts || selection.local_filesystem_mounts.length === 0) @@ -1311,6 +1316,7 @@ export default function NewChatPage() { setAgentCreatedDocuments, queryClient, currentUser, + localFilesystemEnabled, disabledTools, updateChatTabTitle, tokenUsageStore, @@ -1413,7 +1419,9 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const selection = await getAgentFilesystemSelection(searchSpaceId); + const selection = await getAgentFilesystemSelection(searchSpaceId, { + localFilesystemEnabled, + }); const response = await fetchWithTurnCancellingRetry(() => fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { method: "POST", @@ -1561,6 +1569,7 @@ export default function NewChatPage() { pendingInterrupt, messages, searchSpaceId, + localFilesystemEnabled, queryClient, tokenUsageStore, fetchWithTurnCancellingRetry, @@ -1746,7 +1755,9 @@ export default function NewChatPage() { ? messageDocumentsMap[sourceUserMessageId] : []; try { - const selection = await getAgentFilesystemSelection(searchSpaceId); + const selection = await getAgentFilesystemSelection(searchSpaceId, { + localFilesystemEnabled, + }); const requestBody: Record = { search_space_id: searchSpaceId, user_query: newUserQuery, @@ -2016,6 +2027,7 @@ export default function NewChatPage() { searchSpaceId, messages, disabledTools, + localFilesystemEnabled, messageDocumentsMap, setMessageDocumentsMap, queryClient, diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx index bd8f03a70..17d8aa50c 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx @@ -178,6 +178,19 @@ const FLAG_GROUPS: FlagGroup[] = [ }, ], }, + { + id: "desktop", + title: "Desktop", + subtitle: "Desktop-only capabilities exposed by the backend deployment.", + flags: [ + { + key: "enable_desktop_local_filesystem", + label: "Local filesystem", + description: "Allow Desktop chat sessions to operate directly on selected local folders.", + envVar: "ENABLE_DESKTOP_LOCAL_FILESYSTEM", + }, + ], + }, ]; function FlagRow({ def, value }: { def: FlagDef; value: boolean }) { diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx index 2b7422f80..cf73b5eba 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx @@ -28,6 +28,12 @@ type UnifiedPurchase = { kind: PurchaseKind; created_at: string; status: PagePurchaseStatus; + /** + * Granted units. Interpretation depends on ``kind``: + * - ``"pages"`` — integer number of indexed pages. + * - ``"tokens"`` — integer micro-USD of credit (1_000_000 = $1.00). + * The ``Granted`` column formats accordingly. + */ granted: number; amount_total: number | null; currency: string | null; @@ -58,7 +64,7 @@ const KIND_META: Record< iconClass: "text-sky-500", }, tokens: { - label: "Premium Tokens", + label: "Premium Credit", icon: Coins, iconClass: "text-amber-500", }, @@ -97,12 +103,25 @@ function normalizeTokenPurchase(p: TokenPurchase): UnifiedPurchase { kind: "tokens", created_at: p.created_at, status: p.status, - granted: p.tokens_granted, + granted: p.credit_micros_granted, amount_total: p.amount_total, currency: p.currency, }; } +function formatGranted(p: UnifiedPurchase): string { + if (p.kind === "tokens") { + const dollars = p.granted / 1_000_000; + // Premium credit packs are always whole dollars at the moment, but + // future fractional grants (refunds, partial top-ups) shouldn't + // silently round to "$0". + if (dollars >= 1) return `$${dollars.toFixed(2)} of credit`; + if (dollars > 0) return `$${dollars.toFixed(3)} of credit`; + return "$0 of credit"; + } + return p.granted.toLocaleString(); +} + export function PurchaseHistoryContent() { const results = useQueries({ queries: [ @@ -143,7 +162,7 @@ export function PurchaseHistoryContent() {

No purchases yet

- Your page and premium token purchases will appear here after checkout. + Your page and premium credit purchases will appear here after checkout.

); @@ -177,7 +196,7 @@ export function PurchaseHistoryContent() { - {p.granted.toLocaleString()} + {formatGranted(p)} {formatAmount(p.amount_total, p.currency)} diff --git a/surfsense_web/atoms/user/user-query.atoms.ts b/surfsense_web/atoms/user/user-query.atoms.ts index a59811324..4b6717440 100644 --- a/surfsense_web/atoms/user/user-query.atoms.ts +++ b/surfsense_web/atoms/user/user-query.atoms.ts @@ -8,9 +8,9 @@ const userQueryFn = () => userApiService.getMe(); export const currentUserAtom = atomWithQuery(() => { return { queryKey: USER_QUERY_KEY, - // Live-changing numeric fields (pages_*, premium_tokens_*) are now - // pushed via Zero (queries.user.me()), so /users/me only needs to - // fire once per session for the static profile fields. + // Live-changing numeric fields (pages_*, premium_credit_micros_*) + // are now pushed via Zero (queries.user.me()), so /users/me only + // needs to fire once per session for the static profile fields. staleTime: Infinity, enabled: !!getBearerToken(), queryFn: userQueryFn, diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 711bb2fe2..3b9d9a526 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -399,6 +399,19 @@ function formatMessageDate(date: Date): string { }); } +/** + * Format provider USD cost (in micro-USD) for inline display next to a + * token count. Falls back to ``"<$0.001"`` for sub-tenth-of-a-cent + * costs so a real-but-tiny figure doesn't render as ``$0.000``. + */ +function formatTurnCost(micros: number): string { + const dollars = micros / 1_000_000; + if (dollars >= 1) return `$${dollars.toFixed(2)}`; + if (dollars >= 0.01) return `$${dollars.toFixed(3)}`; + if (dollars > 0) return "<$0.001"; + return "$0"; +} + const MessageInfoDropdown: FC = () => { const messageId = useAuiState(({ message }) => message?.id); const createdAt = useAuiState(({ message }) => message?.createdAt); @@ -451,6 +464,7 @@ const MessageInfoDropdown: FC = () => { {models.length > 0 ? ( models.map(([model, counts]) => { const { name, icon } = resolveModel(model); + const costMicros = counts.cost_micros; return ( { {counts.total_tokens.toLocaleString()} tokens + {costMicros && costMicros > 0 ? ` · ${formatTurnCost(costMicros)}` : ""} ); @@ -474,6 +489,9 @@ const MessageInfoDropdown: FC = () => { > {usage.total_tokens.toLocaleString()} tokens + {usage.cost_micros && usage.cost_micros > 0 + ? ` · ${formatTurnCost(usage.cost_micros)}` + : ""} )} diff --git a/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json b/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json index f62758256..b4e85eab0 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json +++ b/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json @@ -9,6 +9,16 @@ "enabled": true, "status": "warning", "statusMessage": "Some requests may be blocked if not using Firecrawl." + }, + "JIRA_CONNECTOR": { + "enabled": false, + "status": "maintenance", + "statusMessage": "Rework in progress." + }, + "CONFLUENCE_CONNECTOR": { + "enabled": false, + "status": "maintenance", + "statusMessage": "Rework in progress." } }, "globalSettings": { diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index ae2c413cf..2f9605ea7 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -105,14 +105,14 @@ export const OAUTH_CONNECTORS = [ { id: "jira-connector", title: "Jira", - description: "Search, read, and manage issues", + description: "Rework in progress.", connectorType: EnumConnectorName.JIRA_CONNECTOR, authEndpoint: "/api/v1/auth/mcp/jira/connector/add/", }, { id: "confluence-connector", title: "Confluence", - description: "Search documentation", + description: "Rework in progress.", connectorType: EnumConnectorName.CONFLUENCE_CONNECTOR, authEndpoint: "/api/v1/auth/confluence/connector/add/", }, diff --git a/surfsense_web/components/assistant-ui/token-usage-context.tsx b/surfsense_web/components/assistant-ui/token-usage-context.tsx index b3f71ab21..dd80bcac3 100644 --- a/surfsense_web/components/assistant-ui/token-usage-context.tsx +++ b/surfsense_web/components/assistant-ui/token-usage-context.tsx @@ -13,13 +13,30 @@ export interface TokenUsageData { prompt_tokens: number; completion_tokens: number; total_tokens: number; + /** + * Total provider USD cost for this assistant turn, in micro-USD + * (1_000_000 = $1.00). Populated from LiteLLM's response_cost on + * the backend. Optional because pre-cost-credits messages persisted + * before the migration won't have it. + */ + cost_micros?: number; usage?: Record< string, - { prompt_tokens: number; completion_tokens: number; total_tokens: number } + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } >; model_breakdown?: Record< string, - { prompt_tokens: number; completion_tokens: number; total_tokens: number } + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } >; } diff --git a/surfsense_web/components/free-chat/quota-warning-banner.tsx b/surfsense_web/components/free-chat/quota-warning-banner.tsx index 3bfedf1b3..e013a64a8 100644 --- a/surfsense_web/components/free-chat/quota-warning-banner.tsx +++ b/surfsense_web/components/free-chat/quota-warning-banner.tsx @@ -40,7 +40,7 @@ export function QuotaWarningBanner({

You've used all {limit.toLocaleString()} free tokens. Create a free account to - get 3 million tokens and access to all models. + get $5 of premium credit and access to all models.

Create an account {" "} - for 5M free tokens. + for $5 of premium credit.

- {(totalTokens / 1_000_000).toFixed(0)}M tokens + ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit ))}
- {(totalTokens / 1_000_000).toFixed(0)}M premium tokens + ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit ${totalPrice}
@@ -149,7 +164,7 @@ export function BuyTokensContent() { ) : ( <> - Buy {(totalTokens / 1_000_000).toFixed(0)}M Tokens for ${totalPrice} + Buy ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit for ${totalPrice} )} diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx index f5f128f80..d4afa698b 100644 --- a/surfsense_web/components/settings/image-model-manager.tsx +++ b/surfsense_web/components/settings/image-model-manager.tsx @@ -22,6 +22,7 @@ import { AlertDialogTitle, } from "@/components/ui/alert-dialog"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent } from "@/components/ui/card"; import { Skeleton } from "@/components/ui/skeleton"; @@ -190,12 +191,98 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { ? "model" : "models"} {" "} - available from your administrator. + available from your administrator. {(() => { + const nonAuto = globalConfigs.filter( + (g) => !("is_auto_mode" in g && g.is_auto_mode) + ); + const premium = nonAuto.filter( + (g) => + "billing_tier" in g && + (g as { billing_tier?: string }).billing_tier === "premium" + ).length; + const free = nonAuto.length - premium; + if (premium > 0 && free > 0) { + return `${premium} premium, ${free} free.`; + } + if (premium > 0) { + return `All ${premium} premium — debits your shared credit pool.`; + } + return `All ${free} free.`; + })()}

)} + {/* Global Image Models — read-only cards with per-model Free/Premium + badges. Mirrors the badge palette used by the chat role selector + (`llm-role-manager.tsx`) so the meaning is consistent across + every model-configuration surface (chat / image / vision). */} + {!isLoading && + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( +
+

+ Global Image Models +

+
+ {globalConfigs + .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) + .map((cfg) => { + const billingTier = + ("billing_tier" in cfg && + typeof (cfg as { billing_tier?: string }).billing_tier === "string" && + (cfg as { billing_tier?: string }).billing_tier) || + "free"; + const isPremium = billingTier === "premium"; + return ( + + +
+
+ {getProviderIcon(cfg.provider, { className: "size-4" })} +
+
+

+ {cfg.name} +

+ {isPremium ? ( + + Premium + + ) : ( + + Free + + )} +
+
+ {cfg.description && ( +

+ {cfg.description} +

+ )} +
+ + {cfg.model_name} + +
+
+
+ ); + })} +
+
+ )} + {/* Loading Skeleton */} {isLoading && (
diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx index e21dc9028..a2eb6a22e 100644 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -371,6 +371,17 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { {roleGlobalConfigs.map((config) => { const isAuto = "is_auto_mode" in config && config.is_auto_mode; + // Read billing_tier from the global config; default to "free" + // for legacy YAMLs / Auto stub. Premium gets a purple badge, + // free gets an emerald one — same palette as the chat + // model selector so the meaning is consistent across + // surfaces (issues E, H). + const billingTier = + ("billing_tier" in config && + typeof config.billing_tier === "string" && + config.billing_tier) || + "free"; + const isPremium = billingTier === "premium"; return ( {config.name} - {isAuto && ( + {isAuto ? ( Recommended + ) : isPremium ? ( + + Premium + + ) : ( + + Free + )}
diff --git a/surfsense_web/components/settings/more-pages-content.tsx b/surfsense_web/components/settings/more-pages-content.tsx index 8de61b0c7..5635c3314 100644 --- a/surfsense_web/components/settings/more-pages-content.tsx +++ b/surfsense_web/components/settings/more-pages-content.tsx @@ -70,9 +70,7 @@ export function MorePagesContent() {

Get Free Pages

-

- Earn bonus pages by completing tasks -

+

Earn bonus pages by completing tasks

diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx index 8abfa4774..34aa531fd 100644 --- a/surfsense_web/components/settings/vision-model-manager.tsx +++ b/surfsense_web/components/settings/vision-model-manager.tsx @@ -22,6 +22,7 @@ import { AlertDialogTitle, } from "@/components/ui/alert-dialog"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent } from "@/components/ui/card"; import { Skeleton } from "@/components/ui/skeleton"; @@ -191,12 +192,98 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { ? "model" : "models"} {" "} - available from your administrator. + available from your administrator. {(() => { + const nonAuto = globalConfigs.filter( + (g) => !("is_auto_mode" in g && g.is_auto_mode) + ); + const premium = nonAuto.filter( + (g) => + "billing_tier" in g && + (g as { billing_tier?: string }).billing_tier === "premium" + ).length; + const free = nonAuto.length - premium; + if (premium > 0 && free > 0) { + return `${premium} premium, ${free} free.`; + } + if (premium > 0) { + return `All ${premium} premium — debits your shared credit pool.`; + } + return `All ${free} free.`; + })()}

)} + {/* Global Vision Models — read-only cards with per-model Free/Premium + badges. Mirrors the badge palette used by the chat role selector + (`llm-role-manager.tsx`) so the meaning is consistent across + every model-configuration surface (chat / image / vision). */} + {!isLoading && + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( +
+

+ Global Vision Models +

+
+ {globalConfigs + .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) + .map((cfg) => { + const billingTier = + ("billing_tier" in cfg && + typeof (cfg as { billing_tier?: string }).billing_tier === "string" && + (cfg as { billing_tier?: string }).billing_tier) || + "free"; + const isPremium = billingTier === "premium"; + return ( + + +
+
+ {getProviderIcon(cfg.provider, { className: "size-4" })} +
+
+

+ {cfg.name} +

+ {isPremium ? ( + + Premium + + ) : ( + + Free + + )} +
+
+ {cfg.description && ( +

+ {cfg.description} +

+ )} +
+ + {cfg.model_name} + +
+
+
+ ); + })} +
+
+ )} + {isLoading && (
diff --git a/surfsense_web/components/tool-ui/generate-podcast.tsx b/surfsense_web/components/tool-ui/generate-podcast.tsx index 02f53efad..e8fff2873 100644 --- a/surfsense_web/components/tool-ui/generate-podcast.tsx +++ b/surfsense_web/components/tool-ui/generate-podcast.tsx @@ -416,9 +416,19 @@ export const GeneratePodcastToolUI = ({ return ; } - // Already generating - show simple warning, don't create another poller - // The FIRST tool call will display the podcast when ready - // (new: "generating", legacy: "already_generating") + // Pending/generating rows have a stable podcast_id, so the card can poll + // independently while the chat stream finishes. + if ( + (result.status === "pending" || + result.status === "generating" || + result.status === "processing") && + result.podcast_id + ) { + return ; + } + + // Legacy duplicate/no-ID result - show a simple warning, don't create + // another poller. The first tool call will display the podcast when ready. if (result.status === "generating" || result.status === "already_generating") { return (
@@ -432,11 +442,6 @@ export const GeneratePodcastToolUI = ({ ); } - // Pending - poll for completion (new: "pending" with podcast_id) - if (result.status === "pending" && result.podcast_id) { - return ; - } - // Ready with podcast_id (new: "ready", legacy: "success") if ((result.status === "ready" || result.status === "success") && result.podcast_id) { return ; diff --git a/surfsense_web/contexts/login-gate.tsx b/surfsense_web/contexts/login-gate.tsx index fad64fa9f..f72cb3a42 100644 --- a/surfsense_web/contexts/login-gate.tsx +++ b/surfsense_web/contexts/login-gate.tsx @@ -44,7 +44,7 @@ export function LoginGateProvider({ children }: { children: ReactNode }) { Create a free account to {feature} - Get 3 million tokens, save chat history, upload documents, use all AI tools, and + Get $5 of premium credit, save chat history, upload documents, use all AI tools, and connect 30+ integrations. diff --git a/surfsense_web/contracts/types/new-llm-config.types.ts b/surfsense_web/contracts/types/new-llm-config.types.ts index ecffc573e..b52b98ae4 100644 --- a/surfsense_web/contracts/types/new-llm-config.types.ts +++ b/surfsense_web/contracts/types/new-llm-config.types.ts @@ -65,6 +65,13 @@ export const newLLMConfig = z.object({ created_at: z.string(), search_space_id: z.number(), user_id: z.string(), + + // Capability flag — derived server-side at the route boundary from + // LiteLLM's authoritative model map. There is no DB column. Default + // `true` is the conservative-allow stance for unknown / unmapped + // BYOK rows; the streaming-task safety net is the only place a + // `false` actually blocks a request. + supports_image_input: z.boolean().default(true), }); /** @@ -74,11 +81,16 @@ export const newLLMConfigPublic = newLLMConfig.omit({ api_key: true }); /** * Create NewLLMConfig + * + * `supports_image_input` is omitted because it is derived server-side + * from LiteLLM's model map at read time — there is no DB column to + * persist a client-supplied value into. */ export const createNewLLMConfigRequest = newLLMConfig.omit({ id: true, created_at: true, user_id: true, + supports_image_input: true, }); export const createNewLLMConfigResponse = newLLMConfig; @@ -114,6 +126,8 @@ export const updateNewLLMConfigRequest = z.object({ created_at: true, search_space_id: true, user_id: true, + // Derived server-side; not part of the writable surface. + supports_image_input: true, }) .partial(), }); @@ -172,6 +186,16 @@ export const globalNewLLMConfig = z.object({ seo_title: z.string().nullable().optional(), seo_description: z.string().nullable().optional(), quota_reserve_tokens: z.number().nullable().optional(), + // Capability flag — true when the model can accept image inputs. + // Resolved server-side (OpenRouter dynamic configs use the OR + // `architecture.input_modalities` field; YAML / BYOK use LiteLLM's + // authoritative `supports_vision` map). The chat selector renders + // an amber "No image" hint when this is false and there are + // pending image attachments, but does not block selection — the + // backend safety net only rejects when LiteLLM *explicitly* marks + // the model as text-only, so unknown / new models still flow + // through. Default `true` matches that conservative-allow stance. + supports_image_input: z.boolean().default(true), }); export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig); @@ -258,6 +282,11 @@ export const globalImageGenConfig = z.object({ litellm_params: z.record(z.string(), z.any()).nullable().optional(), is_global: z.literal(true), is_auto_mode: z.boolean().optional().default(false), + billing_tier: z.string().default("free"), + // Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's + // Free/Premium badge logic lights up automatically for image-gen too. + is_premium: z.boolean().default(false), + quota_reserve_micros: z.number().nullable().optional(), }); export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig); @@ -338,6 +367,13 @@ export const globalVisionLLMConfig = z.object({ litellm_params: z.record(z.string(), z.any()).nullable().optional(), is_global: z.literal(true), is_auto_mode: z.boolean().optional().default(false), + billing_tier: z.string().default("free"), + // Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's + // Free/Premium badge logic lights up automatically for vision too. + is_premium: z.boolean().default(false), + quota_reserve_tokens: z.number().nullable().optional(), + input_cost_per_token: z.number().nullable().optional(), + output_cost_per_token: z.number().nullable().optional(), }); export const getGlobalVisionLLMConfigsResponse = z.array(globalVisionLLMConfig); diff --git a/surfsense_web/contracts/types/stripe.types.ts b/surfsense_web/contracts/types/stripe.types.ts index c8b017044..251f7a176 100644 --- a/surfsense_web/contracts/types/stripe.types.ts +++ b/surfsense_web/contracts/types/stripe.types.ts @@ -32,7 +32,7 @@ export const getPagePurchasesResponse = z.object({ purchases: z.array(pagePurchase), }); -// Premium token purchases +// Premium credit purchases export const createTokenCheckoutSessionRequest = z.object({ quantity: z.number().int().min(1).max(100), search_space_id: z.number().int().min(1), @@ -42,11 +42,16 @@ export const createTokenCheckoutSessionResponse = z.object({ checkout_url: z.string(), }); +// Premium credit balance + purchase records. +// +// The unit is integer micro-USD (1_000_000 == $1.00). The schema names +// kept the ``Token`` prefix for API back-compat with pinned clients; +// the field names below are authoritative. export const tokenStripeStatusResponse = z.object({ token_buying_enabled: z.boolean(), - premium_tokens_used: z.number().default(0), - premium_tokens_limit: z.number().default(0), - premium_tokens_remaining: z.number().default(0), + premium_credit_micros_used: z.number().default(0), + premium_credit_micros_limit: z.number().default(0), + premium_credit_micros_remaining: z.number().default(0), }); export const tokenPurchaseStatusEnum = pagePurchaseStatusEnum; @@ -56,7 +61,7 @@ export const tokenPurchase = z.object({ stripe_checkout_session_id: z.string(), stripe_payment_intent_id: z.string().nullable(), quantity: z.number(), - tokens_granted: z.number(), + credit_micros_granted: z.number(), amount_total: z.number().nullable(), currency: z.string().nullable(), status: tokenPurchaseStatusEnum, diff --git a/surfsense_web/lib/agent-filesystem.ts b/surfsense_web/lib/agent-filesystem.ts index da5fc1b1d..5f8066d27 100644 --- a/surfsense_web/lib/agent-filesystem.ts +++ b/surfsense_web/lib/agent-filesystem.ts @@ -12,6 +12,10 @@ export interface AgentFilesystemSelection { local_filesystem_mounts?: AgentFilesystemMountSelection[]; } +export interface AgentFilesystemSelectionOptions { + localFilesystemEnabled: boolean; +} + const DEFAULT_SELECTION: AgentFilesystemSelection = { filesystem_mode: "cloud", client_platform: "web", @@ -23,10 +27,15 @@ export function getClientPlatform(): ClientPlatform { } export async function getAgentFilesystemSelection( - searchSpaceId?: number | null + searchSpaceId?: number | null, + options?: AgentFilesystemSelectionOptions ): Promise { const platform = getClientPlatform(); - if (platform !== "desktop" || !window.electronAPI?.getAgentFilesystemSettings) { + if ( + platform !== "desktop" || + !options?.localFilesystemEnabled || + !window.electronAPI?.getAgentFilesystemSettings + ) { return { ...DEFAULT_SELECTION, client_platform: platform }; } try { diff --git a/surfsense_web/lib/apis/agent-flags-api.service.ts b/surfsense_web/lib/apis/agent-flags-api.service.ts index 87332ca9f..534810c0e 100644 --- a/surfsense_web/lib/apis/agent-flags-api.service.ts +++ b/surfsense_web/lib/apis/agent-flags-api.service.ts @@ -27,6 +27,8 @@ const AgentFeatureFlagsSchema = z.object({ enable_plugin_loader: z.boolean(), enable_otel: z.boolean(), + + enable_desktop_local_filesystem: z.boolean(), }); export type AgentFeatureFlags = z.infer; diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 95d9848f2..1c67d59a1 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -41,7 +41,7 @@ export interface RawChatErrorInput { } export const PREMIUM_QUOTA_ASSISTANT_MESSAGE = - "I can’t continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue."; + "I can’t continue with the current premium model because your premium credit is exhausted. Switch to a free model or top up your credit to continue."; function getErrorMessage(error: unknown): string { if (error instanceof Error) return error.message; diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 80e7bffbe..6df56f0ce 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -541,16 +541,23 @@ export type SSEEvent = data: { usage: Record< string, - { prompt_tokens: number; completion_tokens: number; total_tokens: number } + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } >; prompt_tokens: number; completion_tokens: number; total_tokens: number; + cost_micros?: number; call_details: Array<{ model: string; prompt_tokens: number; completion_tokens: number; total_tokens: number; + cost_micros?: number; }>; }; } diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts index fc970c26e..7fec60a23 100644 --- a/surfsense_web/lib/chat/thread-persistence.ts +++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -30,9 +30,20 @@ export interface TokenUsageSummary { prompt_tokens: number; completion_tokens: number; total_tokens: number; + /** + * Total provider USD cost for this assistant turn, in micro-USD + * (1_000_000 = $1.00). Optional because rows persisted before the + * cost-credits migration won't have it. + */ + cost_micros?: number; model_breakdown?: Record< string, - { prompt_tokens: number; completion_tokens: number; total_tokens: number } + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } > | null; } diff --git a/surfsense_web/next.config.ts b/surfsense_web/next.config.ts index 5414d548d..6cfcb5187 100644 --- a/surfsense_web/next.config.ts +++ b/surfsense_web/next.config.ts @@ -18,6 +18,12 @@ const nextConfig: NextConfig = { }, images: { remotePatterns: [ + { + protocol: "http", + hostname: "localhost", + port: "8000", + pathname: "/api/v1/image-generations/**", + }, { protocol: "https", hostname: "**", diff --git a/surfsense_web/package.json b/surfsense_web/package.json index 41175daeb..399544019 100644 --- a/surfsense_web/package.json +++ b/surfsense_web/package.json @@ -1,6 +1,6 @@ { "name": "surfsense_web", - "version": "0.0.19", + "version": "0.0.20", "private": true, "description": "SurfSense Frontend", "scripts": { diff --git a/surfsense_web/zero/schema/user.ts b/surfsense_web/zero/schema/user.ts index 0e6234db5..f483fa9b4 100644 --- a/surfsense_web/zero/schema/user.ts +++ b/surfsense_web/zero/schema/user.ts @@ -1,11 +1,20 @@ import { number, string, table } from "@rocicorp/zero"; +/** + * Live-meter slice of the ``user`` table replicated through Zero. + * + * ``premiumCreditMicrosLimit`` / ``premiumCreditMicrosUsed`` are stored + * as integer micro-USD (1_000_000 == $1.00). UI consumers divide by 1M + * when displaying. Sensitive fields (email, hashed_password, oauth, etc.) + * are intentionally omitted via the Postgres column-list publication so + * they never enter WAL replication. + */ export const userTable = table("user") .columns({ id: string(), pagesLimit: number().from("pages_limit"), pagesUsed: number().from("pages_used"), - premiumTokensLimit: number().from("premium_tokens_limit"), - premiumTokensUsed: number().from("premium_tokens_used"), + premiumCreditMicrosLimit: number().from("premium_credit_micros_limit"), + premiumCreditMicrosUsed: number().from("premium_credit_micros_used"), }) .primaryKey("id");